csukuangfj committed
Commit 11ce2d8
Parent: bc4ec38

add 1.23.1

v1.23.1/headers/cpu_provider_factory.h ADDED
@@ -0,0 +1,19 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+
+ #include "onnxruntime_c_api.h"
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ /**
+  * \param use_arena zero: false. non-zero: true.
+  */
+ ORT_EXPORT
+ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
+ ORT_ALL_ARGS_NONNULL;
+
+ #ifdef __cplusplus
+ }
+ #endif
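For orientation, a minimal usage sketch (not part of this diff; it assumes `session_options` is a valid OrtSessionOptions* created via the ORT C API):

    #include "cpu_provider_factory.h"

    // Append the CPU EP; use_arena is zero for false, non-zero for true.
    OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CPU(session_options, /*use_arena=*/1);
    if (status != NULL) {
      // Inspect via OrtApi::GetErrorMessage and release via OrtApi::ReleaseStatus.
    }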
v1.23.1/headers/nnapi_provider_factory.h ADDED
@@ -0,0 +1,62 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+ #pragma once
+
+ #include "onnxruntime_c_api.h"
+
+ // NNAPIFlags are bool options we want to set for the NNAPI EP.
+ // This enum is defined as bit flags and cannot have negative values.
+ // To generate a uint32_t nnapi_flags for use with OrtSessionOptionsAppendExecutionProvider_Nnapi below:
+ //   uint32_t nnapi_flags = 0;
+ //   nnapi_flags |= NNAPI_FLAG_USE_FP16;
+ enum NNAPIFlags {
+   NNAPI_FLAG_USE_NONE = 0x000,
+
+   // Use fp16 relaxation in the NNAPI EP. This may improve performance but may also reduce precision.
+   NNAPI_FLAG_USE_FP16 = 0x001,
+
+   // Use the NCHW layout in the NNAPI EP. This is only available from Android API level 29.
+   // Please note that, for now, NNAPI performs worse using NCHW compared to using NHWC.
+   NNAPI_FLAG_USE_NCHW = 0x002,
+
+   // Prevent NNAPI from using CPU devices.
+   //
+   // NNAPI is more efficient using GPU or NPU for execution, and NNAPI might fall back to its own CPU
+   // implementation for operations not supported by GPU/NPU. The CPU implementation of NNAPI (which is called
+   // nnapi-reference) might be less efficient than ORT's optimized versions of the operations. It might be
+   // advantageous to disable the NNAPI CPU fallback and handle execution using ORT kernels.
+   //
+   // For some models, if NNAPI would use CPU to execute an operation, and this flag is set, the execution of
+   // the model may fall back to ORT kernels.
+   //
+   // This option is only available from Android API level 29, and will be ignored for Android API level 28
+   // and earlier.
+   //
+   // For NNAPI device assignments, see https://developer.android.com/ndk/guides/neuralnetworks#device-assignment
+   // For NNAPI CPU fallback, see https://developer.android.com/ndk/guides/neuralnetworks#cpu-fallback
+   //
+   // Please note, the NNAPI EP will return an error status if both the NNAPI_FLAG_CPU_DISABLED
+   // and NNAPI_FLAG_CPU_ONLY flags are set.
+   NNAPI_FLAG_CPU_DISABLED = 0x004,
+
+   // Use CPU only in the NNAPI EP. This may decrease performance but will provide
+   // reference output values without precision loss, which is useful for validation.
+   //
+   // Please note, the NNAPI EP will return an error status if both the NNAPI_FLAG_CPU_DISABLED
+   // and NNAPI_FLAG_CPU_ONLY flags are set.
+   NNAPI_FLAG_CPU_ONLY = 0x008,
+
+   // Keep NNAPI_FLAG_LAST at the end of the enum definition
+   // and assign the last NNAPIFlag to it.
+   NNAPI_FLAG_LAST = NNAPI_FLAG_CPU_ONLY,
+ };
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi,
+                           _In_ OrtSessionOptions* options, uint32_t nnapi_flags);
+
+ #ifdef __cplusplus
+ }
+ #endif
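A short usage sketch (again not part of the diff; `session_options` is assumed valid), following the flag-composition pattern described in the comments above:

    #include "nnapi_provider_factory.h"

    uint32_t nnapi_flags = 0;
    nnapi_flags |= NNAPI_FLAG_USE_FP16;      // allow fp16 relaxation
    nnapi_flags |= NNAPI_FLAG_CPU_DISABLED;  // let unsupported ops fall back to ORT kernels
    OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, nnapi_flags);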
v1.23.1/headers/onnxruntime_c_api.h ADDED
The diff for this file is too large to render. See raw diff
 
v1.23.1/headers/onnxruntime_cxx_api.h ADDED
The diff for this file is too large to render. See raw diff
 
v1.23.1/headers/onnxruntime_cxx_inline.h ADDED
The diff for this file is too large to render. See raw diff
 
v1.23.1/headers/onnxruntime_ep_c_api.h ADDED
@@ -0,0 +1,988 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+
+ // Do not include this file directly. Please include "onnxruntime_c_api.h" instead.
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ ORT_RUNTIME_CLASS(Ep);
+ ORT_RUNTIME_CLASS(EpFactory);
+ ORT_RUNTIME_CLASS(EpGraphSupportInfo);
+ ORT_RUNTIME_CLASS(MemoryDevice);  // opaque class to wrap onnxruntime::OrtDevice
+ ORT_RUNTIME_CLASS(NodeComputeContext);
+
+ ORT_RUNTIME_CLASS(DataTransferImpl);
+ ORT_RUNTIME_CLASS(SyncNotificationImpl);
+ ORT_RUNTIME_CLASS(SyncStreamImpl);
+
+ // Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU.
+ struct OrtDataTransferImpl {
+   uint32_t ort_version_supported;  ///< Must be initialized to ORT_API_VERSION
+
+   /** \brief Release the OrtDataTransferImpl instance.
+    *
+    * This is called by ORT when the OrtDataTransferImpl instance is no longer needed.
+    * The implementation should release any resources held by the instance.
+    *
+    * \param[in] this_ptr Pointer to the OrtDataTransferImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(void, Release, _In_ OrtDataTransferImpl* this_ptr);
+
+   /** \brief Check if the implementation can copy between the source and destination memory devices.
+    *
+    * \param[in] this_ptr Pointer to the OrtDataTransferImpl instance.
+    * \param[in] src_memory_device Source OrtMemoryDevice to copy from.
+    * \param[in] dst_memory_device Destination OrtMemoryDevice to copy to.
+    * \return True if the implementation can copy between the devices.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(bool, CanCopy, _In_ const OrtDataTransferImpl* this_ptr,
+             _In_ const OrtMemoryDevice* src_memory_device, _In_ const OrtMemoryDevice* dst_memory_device);
+
+   /** \brief Copy tensors from src_tensors to dst_tensors using the provided streams.
+    *
+    * The implementation can use the provided streams to perform asynchronous copies if supported.
+    * If a stream is not available, the copy is performed synchronously.
+    *
+    * \param[in] this_ptr Pointer to the OrtDataTransferImpl instance.
+    * \param[in] src_tensors Array of source OrtValue pointers to copy from.
+    * \param[in] dst_tensors Array of destination OrtValue pointers to copy to.
+    * \param[in] streams Array of OrtSyncStream pointers for the copy operations, if the execution provider is
+    *                    stream aware. nullptr if it is not.
+    * \param[in] num_tensors Number of tensors to copy.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(CopyTensors, _In_ OrtDataTransferImpl* this_ptr,
+                   _In_reads_(num_tensors) const OrtValue** src_tensors,
+                   _In_reads_(num_tensors) OrtValue** dst_tensors,
+                   _In_reads_(num_tensors) OrtSyncStream** streams,
+                   _In_ size_t num_tensors);
+ };
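To make the function-table shape concrete, a hedged sketch of how an EP library might wire up an OrtDataTransferImpl. `MyDataTransfer` and the `*Impl` helpers are hypothetical names; only the fields and signatures come from the struct above:

    // Hypothetical EP-side implementation; a sketch, not ORT's own code.
    struct MyDataTransfer : OrtDataTransferImpl {
      MyDataTransfer() {
        ort_version_supported = ORT_API_VERSION;
        Release = ReleaseImpl;
        CanCopy = CanCopyImpl;
        CopyTensors = CopyTensorsImpl;
      }
      static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) {
        delete static_cast<MyDataTransfer*>(this_ptr);
      }
      static bool ORT_API_CALL CanCopyImpl(const OrtDataTransferImpl* /*this_ptr*/,
                                           const OrtMemoryDevice* /*src*/, const OrtMemoryDevice* /*dst*/) {
        return false;  // e.g., decide via OrtEpApi::MemoryDevice_GetDeviceType / MemoryDevice_AreEqual
      }
      static OrtStatus* ORT_API_CALL CopyTensorsImpl(OrtDataTransferImpl* /*this_ptr*/,
                                                     const OrtValue** /*src_tensors*/, OrtValue** /*dst_tensors*/,
                                                     OrtSyncStream** /*streams*/, size_t /*num_tensors*/) {
        return nullptr;  // nullptr == success; perform the (possibly async) copies here
      }
    };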
+
+ /** \brief Struct that an EP implements for Stream Notifications.
+  *
+  * \since Version 1.23.
+  */
+ struct OrtSyncNotificationImpl {
+   uint32_t ort_version_supported;  ///< Must be initialized to ORT_API_VERSION
+
+   /** \brief Release the OrtSyncNotificationImpl instance.
+    *
+    * This is called by ORT when the OrtSyncNotificationImpl instance is no longer needed.
+    * The implementation should release any resources held by the instance.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(void, Release, _In_ OrtSyncNotificationImpl* this_ptr);
+
+   /** \brief Called by ORT to activate the notification.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(Activate, _In_ OrtSyncNotificationImpl* this_ptr);
+
+   /** \brief Wait for a device to device operation to complete.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
+    * \param[in] consumer_stream The OrtSyncStream instance that will wait on this notification to be activated.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(WaitOnDevice, _In_ OrtSyncNotificationImpl* this_ptr, _In_ OrtSyncStream* consumer_stream);
+
+   /** \brief Wait for a device to host operation to complete.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(WaitOnHost, _In_ OrtSyncNotificationImpl* this_ptr);
+ };
+
+ /** \brief Struct that an EP implements if it wishes to implement Stream support.
+  *
+  * This struct provides the overrides for onnxruntime::Stream's virtual methods.
+  *
+  * \since Version 1.23.
+  */
+ struct OrtSyncStreamImpl {
+   uint32_t ort_version_supported;  ///< Must be initialized to ORT_API_VERSION
+
+   /** \brief Release the OrtSyncStreamImpl instance.
+    *
+    * This is called by ORT when the OrtSyncStreamImpl instance is no longer needed.
+    * The implementation should release any resources held by the instance.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(void, Release, _In_ OrtSyncStreamImpl* this_ptr);
+
+   /** \brief Get the handle of the stream.
+    *
+    * This returns the native handle for the stream, e.g., cudaStream_t for CUDA streams.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
+    * \return The handle of the stream.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(void*, GetHandle, _In_ OrtSyncStreamImpl* this_ptr);
+
+   /** \brief Create an OrtSyncNotificationImpl for the OrtSyncStreamImpl instance.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
+    * \param[out] notification The new OrtSyncNotificationImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(CreateNotification, _In_ OrtSyncStreamImpl* this_ptr,
+                   _Outptr_ OrtSyncNotificationImpl** notification);
+
+   /** \brief Flush the stream.
+    *
+    * This is called by ORT to flush the stream, ensuring that all operations submitted to the stream are
+    * completed.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(Flush, _In_ OrtSyncStreamImpl* this_ptr);
+
+   /** \brief Notify the stream that a session run has ended.
+    *
+    * This is called by ORT to notify the stream that a session run has ended, allowing the stream to perform
+    * any necessary cleanup or finalization.
+    *
+    * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr);
+ };
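A similarly hedged sketch for the stream side; `MySyncStream` and its members are illustrative, with only the OrtSyncStreamImpl fields taken from the header:

    struct MySyncStream : OrtSyncStreamImpl {
      explicit MySyncStream(void* native_handle) : handle(native_handle) {
        ort_version_supported = ORT_API_VERSION;
        Release = ReleaseImpl;
        GetHandle = GetHandleImpl;
        // CreateNotification, Flush, and OnSessionRunEnd would be assigned the same way.
      }
      static void ORT_API_CALL ReleaseImpl(OrtSyncStreamImpl* this_ptr) {
        delete static_cast<MySyncStream*>(this_ptr);
      }
      static void* ORT_API_CALL GetHandleImpl(OrtSyncStreamImpl* this_ptr) {
        return static_cast<MySyncStream*>(this_ptr)->handle;  // e.g., a cudaStream_t
      }
      void* handle;  // the EP's native stream handle
    };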
+
+ struct OrtNodeFusionOptions;
+ typedef struct OrtNodeFusionOptions OrtNodeFusionOptions;
+
+ struct OrtNodeComputeInfo;
+ typedef struct OrtNodeComputeInfo OrtNodeComputeInfo;
+
+ /**
+  * \brief The OrtNodeFusionOptions struct specifies options for fusing nodes supported by an execution provider.
+  *
+  * Refer to OrtEpApi::EpGraphSupportInfo_AddNodesToFuse.
+  *
+  * \since Version 1.23.
+  */
+ struct OrtNodeFusionOptions {
+   /** \brief The ONNX Runtime version the OrtNodeFusionOptions was compiled with.
+    *
+    * Implementation should set to ORT_API_VERSION.
+    * ORT will use this to ensure it does not use members that were not available when the EP library was compiled.
+    *
+    * \since Version 1.23.
+    */
+   uint32_t ort_version_supported;
+
+   /** \brief If set to true, specifies that the execution provider does not require ONNX Runtime to provide
+    * constant initializers as inputs to the fused node during model inference. This is used when the execution
+    * provider saves a copy of constant initializers, and allows ONNX Runtime to release constant initializers
+    * that are not used by any execution provider.
+    *
+    * If not specified, defaults to false. That is, ONNX Runtime provides constant initializers as inputs to
+    * the fused node by default.
+    *
+    * \since Version 1.23.
+    */
+   bool drop_constant_initializers;
+
+   // const OrtNode* fused_node_schema;
+ };
+
+ /**
+  * \brief The OrtNodeComputeInfo struct provides functions that an OrtEp implements to specify the compute
+  *        function for a compiled OrtGraph instance.
+  * \since Version 1.23.
+  */
+ struct OrtNodeComputeInfo {
+   /** \brief The ONNX Runtime version the OrtNodeComputeInfo was compiled with.
+    *
+    * Implementation should set to ORT_API_VERSION.
+    * ORT will use this to ensure it does not call functions that were not available when the EP library was
+    * compiled.
+    *
+    * \since Version 1.23.
+    */
+   uint32_t ort_version_supported;
+
+   /** \brief Creates an opaque compute state object that is then passed to the Compute() function during inference.
+    * \param[in] this_ptr The OrtNodeComputeInfo instance.
+    * \param[in] compute_context OrtNodeComputeContext instance that contains the compiled/fused node's name and
+    *                            host memory allocation functions. Can optionally be used to build the compute
+    *                            state.
+    * \param[out] compute_state Output parameter that is assigned the opaque computation state. ONNX Runtime calls
+    *                           ReleaseState() (after calling Compute()) to allow the implementer to release the
+    *                           compute state.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   OrtStatus*(ORT_API_CALL* CreateState)(_In_ OrtNodeComputeInfo* this_ptr,
+                                         _In_ OrtNodeComputeContext* compute_context,
+                                         _Outptr_ void** compute_state);
+
+   /** \brief Computation function called to execute the fused node compiled by an OrtEp instance.
+    * \param[in] this_ptr The OrtNodeComputeInfo instance.
+    * \param[in] compute_state The opaque computation state returned by CreateState().
+    * \param[in] kernel_context The OrtKernelContext instance used to access inputs/outputs.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   OrtStatus*(ORT_API_CALL* Compute)(_In_ OrtNodeComputeInfo* this_ptr, _In_ void* compute_state,
+                                     _In_ OrtKernelContext* kernel_context);
+
+   /** \brief Releases the compute state returned by CreateState().
+    * \param[in] this_ptr The OrtNodeComputeInfo instance.
+    * \param[inout] compute_state The opaque compute state returned by CreateState().
+    *
+    * \since Version 1.23.
+    */
+   void(ORT_API_CALL* ReleaseState)(_In_ OrtNodeComputeInfo* this_ptr, _Frees_ptr_opt_ void* compute_state);
+ };
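A hedged sketch of filling in OrtNodeComputeInfo; `MyState` and the `My*` functions are hypothetical, while the field order and signatures come from the struct above:

    struct MyState { /* per-node compiled artifacts */ };

    static OrtStatus* ORT_API_CALL MyCreateState(OrtNodeComputeInfo* /*this_ptr*/,
                                                 OrtNodeComputeContext* /*compute_context*/,
                                                 void** compute_state) {
      *compute_state = new MyState();
      return nullptr;  // nullptr == success
    }

    static OrtStatus* ORT_API_CALL MyCompute(OrtNodeComputeInfo* /*this_ptr*/, void* compute_state,
                                             OrtKernelContext* /*kernel_context*/) {
      auto* state = static_cast<MyState*>(compute_state);
      (void)state;  // run the compiled kernel for this fused node here
      return nullptr;
    }

    static void ORT_API_CALL MyReleaseState(OrtNodeComputeInfo* /*this_ptr*/, void* compute_state) {
      delete static_cast<MyState*>(compute_state);
    }

    // Aggregate init matches the field order: version, CreateState, Compute, ReleaseState.
    OrtNodeComputeInfo my_compute_info{ORT_API_VERSION, MyCreateState, MyCompute, MyReleaseState};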
+
+ struct OrtEpApi {
+   /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice.
+    * \param[in] ep_factory Execution provider factory that is creating the instance.
+    * \param[in] hardware_device Hardware device that the EP can utilize.
+    * \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used
+    *                        during execution provider selection and passed to CreateEp.
+    *                        ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
+    * \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added
+    *                       to the Session configuration options if the execution provider is selected.
+    *                       ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
+    * \param[out] ep_device The OrtEpDevice that is created.
+    *
+    * \since Version 1.22.
+    */
+   ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory,
+                   _In_ const OrtHardwareDevice* hardware_device,
+                   _In_opt_ const OrtKeyValuePairs* ep_metadata,
+                   _In_opt_ const OrtKeyValuePairs* ep_options,
+                   _Out_ OrtEpDevice** ep_device);
+
+   ORT_CLASS_RELEASE(EpDevice);
+
+   /** \brief Specify nodes that are supported by an OrtEp and should be fused into one node.
+    *
+    * Because the nodes will be fused into one "fused node", there must not exist an unsupported node in
+    * a path between two of the provided nodes. Otherwise, the graph will become invalid.
+    *
+    * This function can be called multiple times. A subsequent call to this function will force the next set of
+    * nodes to be fused into a different node.
+    *
+    * \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes.
+    * \param[in] nodes Array of nodes supported by the EP that should be fused/compiled.
+    * \param[in] num_nodes The number of supported nodes.
+    * \param[in] node_fusion_options Optional node fusion options. Ignored if set to NULL.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
+                   _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
+                   _In_opt_ const OrtNodeFusionOptions* node_fusion_options);
+
+   /** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel.
+    *
+    * \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported node.
+    * \param[in] node The supported OrtNode instance.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info,
+                   _In_ const OrtNode* node);
+
+   /** \brief Query an OrtNodeComputeContext for the name of the node that encapsulates the compiled/fused node.
+    *
+    * Used in OrtNodeComputeInfo::CreateState().
+    *
+    * \param[in] context The OrtNodeComputeContext instance to query.
+    * \return The node's name.
+    *
+    * \note The returned string is owned by ORT and valid only while OrtNodeComputeInfo::CreateState() is called.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context);
+
+   /** \brief Register an allocator with the OrtEpDevice.
+    *
+    * This allows an EP to provide OrtMemoryInfo for the DEFAULT and HOST_ACCESSIBLE memory types as needed.
+    * The registered values will be used in calls to OrtEpFactory::CreateAllocator to ensure the required
+    * allocator(s) are available for EP usage.
+    *
+    * Multiple calls for the same entry type will replace a previous entry.
+    *
+    * Available entries:
+    *   - OrtDeviceAllocator with type of OrtDeviceMemoryType_DEFAULT
+    *   - OrtDeviceAllocator with type of OrtDeviceMemoryType_HOST_ACCESSIBLE
+    *   - OrtReadOnlyAllocator with type of OrtDeviceMemoryType_DEFAULT
+    *     - If provided, this allocator will only be used to copy initializers to the device the EP uses.
+    *       ORT will use the OrtDeviceAllocator if not provided.
+    *
+    * \param[in] ep_device The OrtEpDevice instance to register the OrtMemoryInfo with.
+    * \param[in] allocator_memory_info The OrtMemoryInfo information for the allocator.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device,
+                   _In_ const OrtMemoryInfo* allocator_memory_info);
+
+   /** \brief Get the OrtMemoryDevice from an OrtMemoryInfo instance.
+    *
+    * This is required for OrtDataTransferImpl (which implements onnxruntime::IDataTransfer) where the
+    * OrtMemoryDevice is used in the CanCopy and CopyTensors functions.
+    *
+    * \param[in] memory_info The OrtMemoryInfo instance to get the memory device from.
+    * \return The OrtMemoryDevice associated with the OrtMemoryInfo instance.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(const OrtMemoryDevice*, MemoryInfo_GetMemoryDevice, _In_ const OrtMemoryInfo* memory_info);
+
+   /** \brief Get the OrtMemoryDevice from an OrtValue instance if it contains a Tensor.
+    *
+    * \param[in] value The OrtValue instance to get the memory device from.
+    * \return Memory device if the OrtValue contains a Tensor, nullptr otherwise.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(const OrtMemoryDevice*, Value_GetMemoryDevice, _In_ const OrtValue* value);
+
+   /** \brief Compare two OrtMemoryDevice instances for equality.
+    *
+    * This is used to check if two memory devices are the same.
+    * Used to implement DataTransferImpl::CanCopy.
+    *
+    * \param[in] a The first OrtMemoryDevice instance to compare.
+    * \param[in] b The second OrtMemoryDevice instance to compare.
+    * \return True if the two OrtMemoryDevice instances are equal, false otherwise.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b);
+
+   /** \brief Get the OrtMemoryInfoDeviceType value from an OrtMemoryDevice instance.
+    *
+    * \param[in] memory_device OrtMemoryDevice instance.
+    * \return The OrtMemoryInfoDeviceType value.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemoryDevice* memory_device);
+
+   /** \brief Get the OrtDeviceMemoryType value from an OrtMemoryDevice instance.
+    *
+    * \param[in] memory_device OrtMemoryDevice instance.
+    * \return The OrtDeviceMemoryType value.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device);
+
+   /** \brief Get the vendor ID from an OrtMemoryDevice instance.
+    *
+    * The vendor ID is used to identify the vendor of the device, and is typically set to the PCI vendor ID.
+    *
+    * If the device is not vendor specific (e.g. CPU memory), the vendor ID is set to 0.
+    *
+    * \param[in] memory_device OrtMemoryDevice instance.
+    * \return The vendor ID value.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device);
+
+   /** \brief Get the device ID from an OrtMemoryDevice instance.
+    *
+    * \param[in] memory_device OrtMemoryDevice instance.
+    * \return The device ID.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device);
+
+   /** \brief Get the OrtSyncStreamImpl associated with an OrtSyncStream instance.
+    *
+    * This allows the plugin library to connect its OrtSyncStreamImpl instance with an OrtSyncStream if needed.
+    *
+    * \param[in] stream The OrtSyncStream instance to find an OrtSyncStreamImpl for.
+    * \return The associated OrtSyncStreamImpl if found, nullptr otherwise.
+    *
+    * \since Version 1.23.
+    *
+    * \remarks There should always be an OrtSyncStreamImpl associated with an OrtSyncStream instance that the
+    *          EP gets.
+    */
+   ORT_API_T(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* stream);
+
+   /** \brief Get the current sync ID for a stream.
+    *
+    * \param[in] stream The OrtSyncStream to get the sync ID for.
+    * \return Current sync ID.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream);
+
+   /** \brief Get the sync ID for the last time the consumer_stream waited on the producer_stream.
+    *
+    * When two streams are synchronized, the sync ID represents the event used in that synchronization.
+    *
+    * \param[in] producer_stream The OrtSyncStream that produced the data.
+    * \param[in] consumer_stream The OrtSyncStream that waited on the producer_stream.
+    * \return ID for the last sync. 0 if no sync has occurred between the two streams.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream,
+             _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream);
+ };
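As an illustration of the fusion entry points, a hedged helper that an EP's GetCapability might call once it has collected the nodes it supports. The helper name is hypothetical, and `ep_api` is assumed to have been obtained from the OrtApi beforehand:

    #include <vector>

    static OrtStatus* ReportFusedPartition(const OrtEpApi* ep_api,
                                           OrtEpGraphSupportInfo* graph_support_info,
                                           const std::vector<const OrtNode*>& supported_nodes) {
      OrtNodeFusionOptions fusion_options{};
      fusion_options.ort_version_supported = ORT_API_VERSION;
      fusion_options.drop_constant_initializers = true;  // assumes the EP keeps its own initializer copies
      return ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
                                                       supported_nodes.data(), supported_nodes.size(),
                                                       &fusion_options);
    }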
+
+ /**
+  * \brief The data layout type.
+  *
+  * EPs may specify a preferred data layout type. ORT's default layout type is OrtEpDataLayout_NCHW
+  * (aliased as OrtEpDataLayout_Default).
+  *
+  * \since Version 1.23.
+  */
+ typedef enum OrtEpDataLayout {
+   OrtEpDataLayout_NCHW = 0,
+   OrtEpDataLayout_NHWC,
+
+   OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
+ } OrtEpDataLayout;
+
+ /**
+  * \brief The OrtEp struct provides functions to implement for an execution provider.
+  * \since Version 1.22.
+  */
+ struct OrtEp {
+   /** \brief The ONNX Runtime version the execution provider was compiled with.
+    *
+    * Implementation should set to ORT_API_VERSION.
+    * ORT will use this to ensure it does not call functions that were not available when the library was compiled.
+    *
+    * \since Version 1.22.
+    */
+   uint32_t ort_version_supported;
+
+   /** \brief Get the execution provider name.
+    *
+    * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \return The execution provider name.
+    *
+    * \since Version 1.22.
+    */
+   ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr);
+
+   /** \brief Get information about the nodes supported by the OrtEp instance.
+    *
+    * IMPORTANT: This is not the final version of this API function. This is currently experimental but will
+    * be stabilized by the ONNX Runtime 1.23 release.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] graph The OrtGraph instance for which to populate node support. The OrtGraph could be a nested
+    *                  subgraph contained by a node (e.g., an If or Loop node). ONNX Runtime calls this function
+    *                  separately for each nested subgraph.
+    * \param[inout] graph_support_info OrtEpGraphSupportInfo instance that the implementer must fill out in order
+    *                                  to specify the supported nodes.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph,
+                   _Inout_ OrtEpGraphSupportInfo* graph_support_info);
+
+   /** \brief Compile OrtGraph instances assigned to the OrtEp. The implementer must set an OrtNodeComputeInfo
+    *         instance for each OrtGraph in order to define its computation function.
+    *
+    * If the session is configured to generate a pre-compiled model, the execution provider must return EPContext
+    * nodes, as OrtNode instances, that ONNX Runtime uses to create a pre-compiled model, known as an
+    * "EPContext model". An EPContext model contains EPContext nodes. Each EPContext node encapsulates the
+    * pre-compiled binary data for an OrtGraph compiled for a specific execution provider. For more details about
+    * the EPContext design, refer to:
+    * \htmlonly
+    * <a href="https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html">EPContext design document.</a>
+    * \endhtmlonly
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] graphs Array of `count` OrtGraph instances to compile. Each graph contains only the nodes for
+    *                   which the execution provider indicated support. Nested subgraphs contained by a
+    *                   node, such as an If or Loop, have separate OrtGraph instances.
+    * \param[in] fused_nodes Array of `count` fused nodes that will replace the compiled graphs.
+    *                        Each fused node is an OrtNode initialized with the intended fused node name and
+    *                        input/output information.
+    * \param[in] count The number of OrtGraph instances to compile.
+    * \param[out] node_compute_infos Array of `count` OrtNodeComputeInfo instances that define each OrtGraph
+    *                                instance's computation function. The implementer allocates the
+    *                                OrtNodeComputeInfo instances. ORT calls ReleaseNodeComputeInfos() to release
+    *                                multiple instances in a batch.
+    * \param[out] ep_context_nodes Output array of `count` OrtNode instances, each representing an EPContext
+    *                              node for a compiled OrtGraph. The execution provider must use
+    *                              OrtModelEditorApi::CreateNode to create the OrtNode instances. ONNX Runtime
+    *                              takes ownership of the OrtNode instances, so the execution provider must NOT
+    *                              call OrtApi::ReleaseNode. Should be ignored if the session is not configured
+    *                              to generate an EPContext model.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \note Do NOT cache the provided OrtGraph instances in any of the OrtNodeComputeInfo functions because the
+    *       graphs are only valid for the duration of the call to Compile. Any graph/node/input/output
+    *       names that are needed by the OrtNodeComputeInfo functions must be copied and stored by the OrtEp.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
+                   _In_ const OrtNode** fused_nodes, _In_ size_t count,
+                   _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
+                   _Out_writes_(count) OrtNode** ep_context_nodes);
+
+   /** \brief Release OrtNodeComputeInfo instances.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[inout] node_compute_infos The OrtNodeComputeInfo instances to release.
+    * \param[in] num_node_compute_infos The number of OrtNodeComputeInfo instances.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr,
+             OrtNodeComputeInfo** node_compute_infos,
+             _In_ size_t num_node_compute_infos);
+
+   /** \brief Get the EP's preferred data layout.
+    *
+    * \note Implementation of this function is optional.
+    *       If not implemented, ORT will assume that this EP prefers the data layout `OrtEpDataLayout_NCHW`.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[out] preferred_data_layout The EP's preferred data layout.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout);
+
+   /** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data
+    *         layout should be converted to `target_data_layout`.
+    *         If the EP prefers a non-default data layout (see `GetPreferredDataLayout()`), this function will be
+    *         called during layout transformation with `target_data_layout` set to the EP's preferred data layout.
+    *
+    * \note Implementation of this function is optional.
+    *       If an EP prefers a non-default data layout, it may implement this to customize the specific op data
+    *       layout preferences at a finer granularity.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] domain The op domain. An empty string means the ONNX domain.
+    * \param[in] op_type The op type.
+    * \param[in] target_data_layout The target data layout.
+    * \param[out] should_convert Whether the associated node's data layout should be converted to
+    *                            `target_data_layout`.
+    *                            If greater than 0, convert.
+    *                            If 0, don't convert.
+    *                            Otherwise, if less than 0, leave the decision to ORT.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr,
+                   _In_z_ const char* domain, _In_z_ const char* op_type,
+                   _In_ OrtEpDataLayout target_data_layout,
+                   _Outptr_ int* should_convert);
+
+   /** \brief Set dynamic options on this EP.
+    *
+    * Dynamic options can be set by the user at any time after session creation with
+    * `OrtApi::SetEpDynamicOptions()`.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] option_keys The dynamic option keys.
+    * \param[in] option_values The dynamic option values.
+    * \param[in] num_options The number of dynamic options.
+    *
+    * \note Implementation of this function is optional.
+    *       An EP should only implement this if it needs to handle any dynamic options.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr,
+                   _In_reads_(num_options) const char* const* option_keys,
+                   _In_reads_(num_options) const char* const* option_values,
+                   _In_ size_t num_options);
+
+   /** \brief Called by ORT to notify the EP of the start of a run.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] run_options The run options for this run.
+    *
+    * \note Implementation of this function is optional.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options);
+
+   /** \brief Called by ORT to notify the EP of the end of a run.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] run_options The run options for this run.
+    * \param[in] sync_stream Whether any associated stream should be synchronized during this call.
+    *                        Only applicable if there is such a stream.
+    *
+    * \note Implementation of this function is optional.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream);
+
+   /** \brief Create an OrtAllocator for the given OrtMemoryInfo for an OrtSession.
+    *
+    * The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using
+    * EpDevice_AddAllocatorInfo. Any allocator specific options should be read from the session options.
+    *
+    * If this function is nullptr, OrtEpFactory::CreateAllocator will be used.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr.
+    * \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(CreateAllocator, _In_ OrtEp* this_ptr,
+                   _In_ const OrtMemoryInfo* memory_info,
+                   _Outptr_result_maybenull_ OrtAllocator** allocator);
+
+   /** \brief Create a synchronization stream for the given memory device for an OrtSession.
+    *
+    * This is used to create a synchronization stream for the execution provider and is used to synchronize
+    * operations on the device during model execution.
+    * Any stream specific options should be read from the session options.
+    *
+    * If this function is nullptr, OrtEpFactory::CreateSyncStreamForDevice will be used.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for.
+    * \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not
+    *                    stream aware.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEp* this_ptr,
+                   _In_ const OrtMemoryDevice* memory_device,
+                   _Outptr_ OrtSyncStreamImpl** stream);
+
+   /** \brief Get a string with details about the EP stack used to produce a compiled model.
+    *
+    * This function gets a compatibility information string that contains details about the execution provider
+    * used to compile a given model. This string can later be used with ValidateCompiledModelCompatibilityInfo
+    * to determine if a compiled model is compatible with the EP.
+    *
+    * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
+    *
+    * \param[in] this_ptr The OrtEp instance.
+    * \param[in] graph The OrtGraph instance for which to generate compatibility information.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(const char*, GetCompiledModelCompatibilityInfo, _In_ OrtEp* this_ptr,
+             _In_ const OrtGraph* graph);
+ };
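A hedged example of the per-op layout hook for an NHWC-preferring EP (the function name and op choices are illustrative only):

    #include <cstring>

    static OrtStatus* ORT_API_CALL MyShouldConvertDataLayoutForOp(OrtEp* /*this_ptr*/,
                                                                  const char* domain, const char* op_type,
                                                                  OrtEpDataLayout target_data_layout,
                                                                  int* should_convert) {
      if (target_data_layout == OrtEpDataLayout_NHWC &&
          std::strcmp(domain, "") == 0 && std::strcmp(op_type, "Conv") == 0) {
        *should_convert = 1;   // convert Conv nodes to NHWC
      } else {
        *should_convert = -1;  // leave the decision to ORT for everything else
      }
      return nullptr;
    }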
+
+ /** \brief The function signature that ORT will call to create OrtEpFactory instances.
+  *
+  * This must be available in a function called 'CreateEpFactories' in the execution provider library.
+  *
+  * \param[in] registered_name The name the execution provider library is registered with by
+  *                            RegisterExecutionProviderLibrary.
+  * \param[in] ort_api_base The OrtApiBase instance that is used by the factory to get the OrtApi instance for
+  *                         the version of ORT that the library was compiled against.
+  * \param[in] default_logger The default ORT logger that can be used for logging outside of an inference session.
+  * \param[in,out] factories The implementation should create and add OrtEpFactory instances to this
+  *                          pre-allocated array.
+  *                          i.e. usage is `factories[0] = new MyEpFactory();`
+  * \param[in] max_factories The maximum number of OrtEpFactory instances that can be added to `factories`.
+  *                          The current default is to allow 4 factories. This can be increased in the future
+  *                          if needed.
+  * \param[out] num_factories The number of OrtEpFactory instances created and added to `factories`.
+  *
+  * \snippet{doc} snippets.dox OrtStatus Return Value
+  *
+  * \since Version 1.22.
+  */
+ typedef OrtStatus* (*CreateEpApiFactoriesFn)(_In_ const char* registered_name, _In_ const OrtApiBase* ort_api_base,
+                                              _In_ const OrtLogger* default_logger,
+                                              _Inout_ OrtEpFactory** factories, _In_ size_t max_factories,
+                                              _Out_ size_t* num_factories);
+
+ /** \brief The function signature that ORT will call to release an OrtEpFactory instance.
+  *
+  * This must be available in a function called 'ReleaseEpFactory' in the execution provider library.
+  *
+  * \param[in] factory The OrtEpFactory instance to release.
+  *
+  * \snippet{doc} snippets.dox OrtStatus Return Value
+  *
+  * \since Version 1.22.
+  */
+ typedef OrtStatus* (*ReleaseEpApiFactoryFn)(_In_ OrtEpFactory* factory);
+
+ /**
+  * \brief The OrtEpFactory provides functions to create and manage execution providers.
+  * \since Version 1.22.
+  */
+ struct OrtEpFactory {
+   /** \brief The ONNX Runtime version the execution provider was compiled with.
+    *
+    * Implementation should set to ORT_API_VERSION.
+    * ORT will use this to ensure it does not call functions that were not available when the library was compiled.
+    *
+    * \since Version 1.22.
+    */
+   uint32_t ort_version_supported;
+
+   /** \brief Get the name of the execution provider that the factory creates.
+    *
+    * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \return The name of the execution provider the factory creates.
+    *
+    * \since Version 1.22.
+    */
+   ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr);
+
+   /** \brief Get the name of the vendor who owns the execution provider that the factory creates.
+    *
+    * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \return The vendor name of the execution provider the factory creates.
+    *
+    * \since Version 1.22.
+    */
+   ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr);  // return EP vendor
+
+   /** \brief Get information from the execution provider about OrtHardwareDevice support.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    *                     Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice.
+    * \param[in] devices The OrtHardwareDevice instances that are available.
+    * \param[in] num_devices The number of OrtHardwareDevice instances.
+    * \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use.
+    *                        The implementation should call OrtEpApi::CreateEpDevice to create, and add the
+    *                        OrtEpDevice instances to this pre-allocated array. ORT will take ownership of the
+    *                        values returned.
+    *                        i.e. usage is `ep_devices[0] = <ptr to OrtEpDevice created with OrtEpApi::CreateEpDevice>;`
+    * \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices.
+    *                           The current default is 8. This can be increased if needed.
+    * \param[out] num_ep_devices The number of EP devices added to ep_devices.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.22.
+    */
+   ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr,
+                   _In_reads_(num_devices) const OrtHardwareDevice* const* devices,
+                   _In_ size_t num_devices,
+                   _Inout_ OrtEpDevice** ep_devices,
+                   _In_ size_t max_ep_devices,
+                   _Out_ size_t* num_ep_devices);
+
+   /** \brief Function to create an OrtEp instance for use in a Session.
+    *
+    * ORT will call ReleaseEp to release the instance when it is no longer needed.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use.
+    *                    May be a subset of the OrtHardwareDevice instances that the execution provider's factory
+    *                    set as supported in the call to OrtEpFactory::GetSupportedDevices.
+    * \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice,
+    *                              for each device.
+    * \param[in] num_devices The number of devices the execution provider was selected for.
+    * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the
+    *                            session. This will include ep_options from GetSupportedDevices as well as any
+    *                            user provided overrides.
+    *                            Execution provider options will have been added with a prefix of 'ep.[ep name].'.
+    *                            The OrtSessionOptions instance will NOT be valid after this call and should not
+    *                            be stored for later use.
+    * \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging.
+    * \param[out] ep The OrtEp instance created by the factory.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.22.
+    */
+   ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr,
+                   _In_reads_(num_devices) const OrtHardwareDevice* const* devices,
+                   _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs,
+                   _In_ size_t num_devices,
+                   _In_ const OrtSessionOptions* session_options,
+                   _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep);
+
+   /** \brief Release the OrtEp instance.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \param[in] ep The OrtEp instance to release.
+    *
+    * \since Version 1.22.
+    */
+   ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep);
+
+   /** \brief Get the ID of the vendor who owns the execution provider that the factory creates.
+    *
+    * This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \return The vendor ID of the execution provider the factory creates.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr);
+
+   /** \brief Get the version of the execution provider that the factory creates.
+    *
+    * The version string should adhere to the Semantic Versioning 2.0 specification
+    * (https://github.com/semver/semver/blob/v2.0.0/semver.md).
+    *
+    * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \return The execution provider version string.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr);
+
+   /** \brief Validate the compatibility of a compiled model with the execution provider factory for one or more
+    *         devices.
+    *
+    * Given a compatibility info string produced during model compilation, the EP factory should determine whether
+    * the compiled model is compatible with the EP factory when targeting the provided hardware devices. All
+    * devices provided must belong to the same execution provider instance that this factory creates.
+    *
+    * The EP factory implementation should consider the set of devices (e.g., multi-adapter or multi-GPU
+    * scenarios) when evaluating compatibility and set `model_compatibility` accordingly.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \param[in] devices Array of OrtHardwareDevice pointers that the EP would run on. All must map to this EP.
+    * \param[in] num_devices Number of entries in `devices`.
+    * \param[in] compatibility_info The compatibility information string produced when the model was compiled.
+    * \param[out] model_compatibility OrtCompiledModelCompatibility value describing the compatibility of the
+    *                                 model with the EP.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(ValidateCompiledModelCompatibilityInfo, _In_ OrtEpFactory* this_ptr,
+                   _In_reads_(num_devices) const OrtHardwareDevice* const* devices,
+                   _In_ size_t num_devices,
+                   _In_ const char* compatibility_info,
+                   _Out_ OrtCompiledModelCompatibility* model_compatibility);
+
+   /** \brief Create an OrtAllocator that can be shared across sessions for the given OrtMemoryInfo.
+    *
+    * The factory that creates the EP is responsible for providing the allocators required by the EP.
+    * The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using
+    * EpDevice_AddAllocatorInfo.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr.
+    * \param[in] allocator_options Optional key-value pairs for allocator options, can be nullptr.
+    * \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(CreateAllocator, _In_ OrtEpFactory* this_ptr,
+                   _In_ const OrtMemoryInfo* memory_info,
+                   _In_opt_ const OrtKeyValuePairs* allocator_options,
+                   _Outptr_result_maybenull_ OrtAllocator** allocator);
+
+   /** \brief Release an OrtAllocator created by the factory.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \param[in] allocator The OrtAllocator instance to release.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(void, ReleaseAllocator, _In_ OrtEpFactory* this_ptr, _In_ OrtAllocator* allocator);
+
+   /** \brief Create an OrtDataTransferImpl instance for the factory.
+    *
+    * This is used to create an IDataTransfer implementation that can be used to copy data between devices
+    * that the execution provider supports.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \param[out] data_transfer The created OrtDataTransferImpl instance. Set to nullptr if not required.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(CreateDataTransfer, _In_ OrtEpFactory* this_ptr,
+                   _Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer);
+
+   /** \brief Check if execution providers created by the factory are stream aware.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \return True if the factory creates execution providers that are stream aware and it implements
+    *         CreateSyncStreamForDevice.
+    *
+    * \since Version 1.23.
+    */
+   ORT_API_T(bool, IsStreamAware, _In_ const OrtEpFactory* this_ptr);
+
+   /** \brief Create a synchronization stream for the given memory device.
+    *
+    * This is used to create a synchronization stream for the memory device that can be used for operations
+    * outside of a session.
+    *
+    * \param[in] this_ptr The OrtEpFactory instance.
+    * \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for.
+    * \param[in] stream_options Options for stream creation. May be nullptr.
+    * \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not
+    *                    stream aware.
+    *
+    * \snippet{doc} snippets.dox OrtStatus Return Value
+    *
+    * \since Version 1.23.
+    */
+   ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEpFactory* this_ptr,
+                   _In_ const OrtMemoryDevice* memory_device,
+                   _In_opt_ const OrtKeyValuePairs* stream_options,
+                   _Outptr_ OrtSyncStreamImpl** stream);
+ };
+
+ #ifdef __cplusplus
+ }
+ #endif
v1.23.1/headers/onnxruntime_ep_device_ep_metadata_keys.h ADDED
@@ -0,0 +1,18 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+
+ #pragma once
+
+ // This file contains well-known keys for OrtEpDevice EP metadata entries.
+ // It does NOT specify all available metadata keys.
+
+ // Key for the execution provider version string. This should be available for all plugin EPs.
+ static const char* const kOrtEpDevice_EpMetadataKey_Version = "version";
+
+ // Prefix for execution provider compatibility information stored in model metadata.
+ // Used when generating EP context models to store compatibility strings for each EP.
+ // Full key format: "ep_compatibility_info.<EP_TYPE>"
+ static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info.";
+
+ // Key for the execution provider library path (for dynamically loaded EPs).
+ static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path";
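A small illustration (not part of the diff) of composing the full model-metadata key, where "MyEp" stands in for an actual EP type:

    #include <string>

    // Per the comment above, the full key format is "ep_compatibility_info.<EP_TYPE>".
    std::string key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + "MyEp";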
v1.23.1/headers/onnxruntime_float16.h ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ #include <stdint.h>
7
+ #include <cmath>
8
+ #include <cstring>
9
+ #include <limits>
10
+
11
+ namespace onnxruntime_float16 {
12
+
13
+ namespace detail {
14
+
15
+ enum class endian {
16
+ #if defined(_WIN32)
17
+ little = 0,
18
+ big = 1,
19
+ native = little,
20
+ #elif defined(__GNUC__) || defined(__clang__)
21
+ little = __ORDER_LITTLE_ENDIAN__,
22
+ big = __ORDER_BIG_ENDIAN__,
23
+ native = __BYTE_ORDER__,
24
+ #else
25
+ #error onnxruntime_float16::detail::endian is not implemented in this environment.
26
+ #endif
27
+ };
28
+
29
+ static_assert(
30
+ endian::native == endian::little || endian::native == endian::big,
31
+ "Only little-endian or big-endian native byte orders are supported.");
32
+
33
+ } // namespace detail
34
+
35
+ /// <summary>
36
+ /// Shared implementation between public and internal classes. CRTP pattern.
37
+ /// </summary>
38
+ template <class Derived>
39
+ struct Float16Impl {
40
+ protected:
41
+ /// <summary>
42
+ /// Converts from float to uint16_t float16 representation
43
+ /// </summary>
44
+ /// <param name="v"></param>
45
+ /// <returns></returns>
46
+ constexpr static uint16_t ToUint16Impl(float v) noexcept;
47
+
48
+ /// <summary>
49
+ /// Converts float16 to float
50
+ /// </summary>
51
+ /// <returns>float representation of float16 value</returns>
52
+ float ToFloatImpl() const noexcept;
53
+
54
+ /// <summary>
55
+ /// Creates an instance that represents absolute value.
56
+ /// </summary>
57
+ /// <returns>Absolute value</returns>
58
+ uint16_t AbsImpl() const noexcept {
59
+ return static_cast<uint16_t>(val & ~kSignMask);
60
+ }
61
+
62
+ /// <summary>
63
+ /// Creates a new instance with the sign flipped.
64
+ /// </summary>
65
+ /// <returns>Flipped sign instance</returns>
66
+ uint16_t NegateImpl() const noexcept {
67
+ return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
68
+ }
69
+
70
+ public:
71
+ // uint16_t special values
72
+ static constexpr uint16_t kSignMask = 0x8000U;
73
+ static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
74
+ static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
75
+ static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
76
+ static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
77
+ static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
78
+ static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
79
+ static constexpr uint16_t kOneBits = 0x3C00U;
80
+ static constexpr uint16_t kMinusOneBits = 0xBC00U;
81
+
82
+ uint16_t val{0};
83
+
84
+ Float16Impl() = default;
85
+
86
+ /// <summary>
87
+ /// Checks if the value is negative
88
+ /// </summary>
89
+ /// <returns>true if negative</returns>
90
+ bool IsNegative() const noexcept {
91
+ return static_cast<int16_t>(val) < 0;
92
+ }
93
+
94
+ /// <summary>
95
+ /// Tests if the value is NaN
96
+ /// </summary>
97
+ /// <returns>true if NaN</returns>
98
+ bool IsNaN() const noexcept {
99
+ return AbsImpl() > kPositiveInfinityBits;
100
+ }
101
+
102
+ /// <summary>
103
+ /// Tests if the value is finite
104
+ /// </summary>
105
+ /// <returns>true if finite</returns>
106
+ bool IsFinite() const noexcept {
107
+ return AbsImpl() < kPositiveInfinityBits;
108
+ }
109
+
110
+ /// <summary>
111
+ /// Tests if the value represents positive infinity.
112
+ /// </summary>
113
+ /// <returns>true if positive infinity</returns>
114
+ bool IsPositiveInfinity() const noexcept {
115
+ return val == kPositiveInfinityBits;
116
+ }
117
+
118
+ /// <summary>
119
+ /// Tests if the value represents negative infinity
120
+ /// </summary>
121
+ /// <returns>true if negative infinity</returns>
122
+ bool IsNegativeInfinity() const noexcept {
123
+ return val == kNegativeInfinityBits;
124
+ }
125
+
126
+ /// <summary>
127
+ /// Tests if the value is either positive or negative infinity.
128
+ /// </summary>
129
+ /// <returns>True if absolute value is infinity</returns>
130
+ bool IsInfinity() const noexcept {
131
+ return AbsImpl() == kPositiveInfinityBits;
132
+ }
133
+
134
+ /// <summary>
135
+ /// Tests if the value is NaN or zero. Useful for comparisons.
136
+ /// </summary>
137
+ /// <returns>True if NaN or zero.</returns>
138
+ bool IsNaNOrZero() const noexcept {
139
+ auto abs = AbsImpl();
140
+ return (abs == 0 || abs > kPositiveInfinityBits);
141
+ }
142
+
143
+ /// <summary>
144
+ /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
145
+ /// </summary>
146
+ /// <returns>True if so</returns>
147
+ bool IsNormal() const noexcept {
148
+ auto abs = AbsImpl();
149
+ return (abs < kPositiveInfinityBits) // is finite
150
+ && (abs != 0) // is not zero
151
+ && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
152
+ }
153
+
154
+ /// <summary>
155
+ /// Tests if the value is subnormal (denormal).
156
+ /// </summary>
157
+ /// <returns>True if so</returns>
158
+ bool IsSubnormal() const noexcept {
159
+ auto abs = AbsImpl();
160
+ return (abs < kPositiveInfinityBits) // is finite
161
+ && (abs != 0) // is not zero
162
+ && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
163
+ }
164
+
165
+ /// <summary>
166
+ /// Creates an instance that represents absolute value.
167
+ /// </summary>
168
+ /// <returns>Absolute value</returns>
169
+ Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
170
+
171
+ /// <summary>
172
+ /// Creates a new instance with the sign flipped.
173
+ /// </summary>
174
+ /// <returns>Flipped sign instance</returns>
175
+ Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
176
+
177
+ /// <summary>
178
+ /// IEEE defines that positive and negative zero are equal; this gives us a quick equality check
179
+ /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
180
+ /// and therefore equivalent, if the resulting value is still zero.
181
+ /// </summary>
182
+ /// <param name="lhs">first value</param>
183
+ /// <param name="rhs">second value</param>
184
+ /// <returns>True if both arguments represent zero</returns>
185
+ static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
186
+ return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
187
+ }
188
+
189
+ bool operator==(const Float16Impl& rhs) const noexcept {
190
+ if (IsNaN() || rhs.IsNaN()) {
191
+ // IEEE defines that NaN is not equal to anything, including itself.
192
+ return false;
193
+ }
194
+ return val == rhs.val;
195
+ }
196
+
197
+ bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
198
+
199
+ bool operator<(const Float16Impl& rhs) const noexcept {
200
+ if (IsNaN() || rhs.IsNaN()) {
201
+ // IEEE defines that NaN is unordered with respect to everything, including itself.
202
+ return false;
203
+ }
204
+
205
+ const bool left_is_negative = IsNegative();
206
+ if (left_is_negative != rhs.IsNegative()) {
207
+ // When the signs of left and right differ, we know that left is less than right if it is
208
+ // the negative value. The exception to this is if both values are zero, in which case IEEE
209
+ // says they should be equal, even if the signs differ.
210
+ return left_is_negative && !AreZero(*this, rhs);
211
+ }
212
+ return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
213
+ }
214
+ };
215
+
216
+ // The following Float16_t conversions are based on the code from
217
+ // the Eigen library.
218
+
219
+ // The conversion routines are Copyright (c) Fabian Giesen, 2016.
220
+ // The original license follows:
221
+ //
222
+ // Copyright (c) Fabian Giesen, 2016
223
+ // All rights reserved.
224
+ // Redistribution and use in source and binary forms, with or without
225
+ // modification, are permitted.
226
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
227
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
228
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
229
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
230
+ // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
231
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
232
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
233
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
234
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
235
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
236
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
237
+
238
+ namespace detail {
239
+ union float32_bits {
240
+ unsigned int u;
241
+ float f;
242
+ };
243
+ } // namespace detail
244
+
245
+ template <class Derived>
246
+ inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
247
+ detail::float32_bits f{};
248
+ f.f = v;
249
+
250
+ constexpr detail::float32_bits f32infty = {255 << 23};
251
+ constexpr detail::float32_bits f16max = {(127 + 16) << 23};
252
+ constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
253
+ constexpr unsigned int sign_mask = 0x80000000u;
254
+ uint16_t val = static_cast<uint16_t>(0x0u);
255
+
256
+ unsigned int sign = f.u & sign_mask;
257
+ f.u ^= sign;
258
+
259
+ // NOTE all the integer compares in this function can be safely
260
+ // compiled into signed compares since all operands are below
261
+ // 0x80000000. Important if you want fast straight SSE2 code
262
+ // (since there's no unsigned PCMPGTD).
263
+
264
+ if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
265
+ val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
266
+ } else { // (De)normalized number or zero
267
+ if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
268
+ // use a magic value to align our 10 mantissa bits at the bottom of
269
+ // the float. as long as FP addition is round-to-nearest-even this
270
+ // just works.
271
+ f.f += denorm_magic.f;
272
+
273
+ // and one integer subtract of the bias later, we have our final half-precision bits!
274
+ val = static_cast<uint16_t>(f.u - denorm_magic.u);
275
+ } else {
276
+ unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
277
+
278
+ // update exponent, rounding bias part 1
279
+ // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
280
+ // without arithmetic overflow.
281
+ f.u += 0xc8000fffU;
282
+ // rounding bias part 2
283
+ f.u += mant_odd;
284
+ // take the bits!
285
+ val = static_cast<uint16_t>(f.u >> 13);
286
+ }
287
+ }
288
+
289
+ val |= static_cast<uint16_t>(sign >> 16);
290
+ return val;
291
+ }
292
+
293
+ template <class Derived>
294
+ inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
295
+ constexpr detail::float32_bits magic = {113 << 23};
296
+ constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
297
+ detail::float32_bits o{};
298
+
299
+ o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
300
+ unsigned int exp = shifted_exp & o.u; // just the exponent
301
+ o.u += (127 - 15) << 23; // exponent adjust
302
+
303
+ // handle exponent special cases
304
+ if (exp == shifted_exp) { // Inf/NaN?
305
+ o.u += (128 - 16) << 23; // extra exp adjust
306
+ } else if (exp == 0) { // Zero/Denormal?
307
+ o.u += 1 << 23; // extra exp adjust
308
+ o.f -= magic.f; // re-normalize
309
+ }
310
+
311
+ // Attempt to work around the Internal Compiler Error on ARM64
312
+ // for bitwise | operator, including std::bitset
313
+ #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
314
+ if (IsNegative()) {
315
+ return -o.f;
316
+ }
317
+ #else
318
+ // original code:
319
+ o.u |= (val & 0x8000U) << 16U; // sign bit
320
+ #endif
321
+ return o.f;
322
+ }
323
+
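
A minimal sketch (not part of the diff) of how the conversion routines above behave when driven through a concrete CRTP derivative. MyFloat16 and its FromFloat/ToFloat shims are hypothetical; in the real API, Ort::Float16_t from onnxruntime_cxx_api.h plays this role.

#include <cassert>
#include <cstdint>
#include "onnxruntime_float16.h"

// Hypothetical CRTP derivative; Float16Impl expects the derived type to
// provide FromBits (it is used by Abs() and Negate() above).
struct MyFloat16 : onnxruntime_float16::Float16Impl<MyFloat16> {
  static MyFloat16 FromBits(uint16_t bits) noexcept {
    MyFloat16 f;
    f.val = bits;
    return f;
  }
  // The conversion helpers are protected in the base, so re-expose them here.
  static MyFloat16 FromFloat(float v) noexcept { return FromBits(ToUint16Impl(v)); }
  float ToFloat() const noexcept { return ToFloatImpl(); }
};

int main() {
  MyFloat16 h = MyFloat16::FromFloat(1.5f);
  assert(h.val == 0x3E00);                           // 0 01111 1000000000
  assert(h.ToFloat() == 1.5f);                       // 1.5 is exactly representable
  assert(MyFloat16::FromBits(0x7E00).IsNaN());       // kPositiveQNaNBits
  assert(MyFloat16::FromBits(0x7C00).IsInfinity());  // kPositiveInfinityBits
  return 0;
}
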
324
+ /// Shared implementation between public and internal classes. CRTP pattern.
325
+ template <class Derived>
326
+ struct BFloat16Impl {
327
+ protected:
328
+ /// <summary>
329
+ /// Converts from float to uint16_t bfloat16 representation
330
+ /// </summary>
331
+ /// <param name="v">Single-precision value to convert.</param>
332
+ /// <returns>The bfloat16 bits of the rounded value.</returns>
333
+ static uint16_t ToUint16Impl(float v) noexcept;
334
+
335
+ /// <summary>
336
+ /// Converts bfloat16 to float
337
+ /// </summary>
338
+ /// <returns>float representation of bfloat16 value</returns>
339
+ float ToFloatImpl() const noexcept;
340
+
341
+ /// <summary>
342
+ /// Creates an instance that represents absolute value.
343
+ /// </summary>
344
+ /// <returns>Absolute value</returns>
345
+ uint16_t AbsImpl() const noexcept {
346
+ return static_cast<uint16_t>(val & ~kSignMask);
347
+ }
348
+
349
+ /// <summary>
350
+ /// Creates a new instance with the sign flipped.
351
+ /// </summary>
352
+ /// <returns>Flipped sign instance</returns>
353
+ uint16_t NegateImpl() const noexcept {
354
+ return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
355
+ }
356
+
357
+ public:
358
+ // uint16_t special values
359
+ static constexpr uint16_t kSignMask = 0x8000U;
360
+ static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
361
+ static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
362
+ static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
363
+ static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
364
+ static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
365
+ static constexpr uint16_t kMaxValueBits = 0x7F7FU;
366
+ static constexpr uint16_t kRoundToNearest = 0x7FFFU;
367
+ static constexpr uint16_t kOneBits = 0x3F80U;
368
+ static constexpr uint16_t kMinusOneBits = 0xBF80U;
369
+
370
+ uint16_t val{0};
371
+
372
+ BFloat16Impl() = default;
373
+
374
+ /// <summary>
375
+ /// Checks if the value is negative
376
+ /// </summary>
377
+ /// <returns>true if negative</returns>
378
+ bool IsNegative() const noexcept {
379
+ return static_cast<int16_t>(val) < 0;
380
+ }
381
+
382
+ /// <summary>
383
+ /// Tests if the value is NaN
384
+ /// </summary>
385
+ /// <returns>true if NaN</returns>
386
+ bool IsNaN() const noexcept {
387
+ return AbsImpl() > kPositiveInfinityBits;
388
+ }
389
+
390
+ /// <summary>
391
+ /// Tests if the value is finite
392
+ /// </summary>
393
+ /// <returns>true if finite</returns>
394
+ bool IsFinite() const noexcept {
395
+ return AbsImpl() < kPositiveInfinityBits;
396
+ }
397
+
398
+ /// <summary>
399
+ /// Tests if the value represents positive infinity.
400
+ /// </summary>
401
+ /// <returns>true if positive infinity</returns>
402
+ bool IsPositiveInfinity() const noexcept {
403
+ return val == kPositiveInfinityBits;
404
+ }
405
+
406
+ /// <summary>
407
+ /// Tests if the value represents negative infinity
408
+ /// </summary>
409
+ /// <returns>true if negative infinity</returns>
410
+ bool IsNegativeInfinity() const noexcept {
411
+ return val == kNegativeInfinityBits;
412
+ }
413
+
414
+ /// <summary>
415
+ /// Tests if the value is either positive or negative infinity.
416
+ /// </summary>
417
+ /// <returns>True if absolute value is infinity</returns>
418
+ bool IsInfinity() const noexcept {
419
+ return AbsImpl() == kPositiveInfinityBits;
420
+ }
421
+
422
+ /// <summary>
423
+ /// Tests if the value is NaN or zero. Useful for comparisons.
424
+ /// </summary>
425
+ /// <returns>True if NaN or zero.</returns>
426
+ bool IsNaNOrZero() const noexcept {
427
+ auto abs = AbsImpl();
428
+ return (abs == 0 || abs > kPositiveInfinityBits);
429
+ }
430
+
431
+ /// <summary>
432
+ /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
433
+ /// </summary>
434
+ /// <returns>True if so</returns>
435
+ bool IsNormal() const noexcept {
436
+ auto abs = AbsImpl();
437
+ return (abs < kPositiveInfinityBits) // is finite
438
+ && (abs != 0) // is not zero
439
+ && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
440
+ }
441
+
442
+ /// <summary>
443
+ /// Tests if the value is subnormal (denormal).
444
+ /// </summary>
445
+ /// <returns>True if so</returns>
446
+ bool IsSubnormal() const noexcept {
447
+ auto abs = AbsImpl();
448
+ return (abs < kPositiveInfinityBits) // is finite
449
+ && (abs != 0) // is not zero
450
+ && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
451
+ }
452
+
453
+ /// <summary>
454
+ /// Creates an instance that represents absolute value.
455
+ /// </summary>
456
+ /// <returns>Absolute value</returns>
457
+ Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
458
+
459
+ /// <summary>
460
+ /// Creates a new instance with the sign flipped.
461
+ /// </summary>
462
+ /// <returns>Flipped sign instance</returns>
463
+ Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
464
+
465
+ /// <summary>
466
+ /// IEEE defines that positive and negative zero are equal; this gives us a quick equality check
467
+ /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
468
+ /// and therefore equivalent, if the resulting value is still zero.
469
+ /// </summary>
470
+ /// <param name="lhs">first value</param>
471
+ /// <param name="rhs">second value</param>
472
+ /// <returns>True if both arguments represent zero</returns>
473
+ static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
474
+ // IEEE defines that positive and negative zero are equal; this gives us a quick equality check
475
+ // for two values by or'ing the private bits together and stripping the sign. They are both zero,
476
+ // and therefore equivalent, if the resulting value is still zero.
477
+ return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
478
+ }
479
+ };
480
+
481
+ template <class Derived>
482
+ inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
483
+ uint16_t result;
484
+ if (std::isnan(v)) {
485
+ result = kPositiveQNaNBits;
486
+ } else {
487
+ auto get_msb_half = [](float fl) {
488
+ uint16_t result;
489
+ #ifdef __cpp_if_constexpr
490
+ if constexpr (detail::endian::native == detail::endian::little) {
491
+ #else
492
+ if (detail::endian::native == detail::endian::little) {
493
+ #endif
494
+ std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
495
+ } else {
496
+ std::memcpy(&result, &fl, sizeof(uint16_t));
497
+ }
498
+ return result;
499
+ };
500
+
501
+ uint16_t upper_bits = get_msb_half(v);
502
+ union {
503
+ uint32_t U32;
504
+ float F32;
505
+ };
506
+ F32 = v;
507
+ U32 += (upper_bits & 1) + kRoundToNearest;
508
+ result = get_msb_half(F32);
509
+ }
510
+ return result;
511
+ }
512
+
513
+ template <class Derived>
514
+ inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
515
+ if (IsNaN()) {
516
+ return std::numeric_limits<float>::quiet_NaN();
517
+ }
518
+ float result;
519
+ char* const first = reinterpret_cast<char*>(&result);
520
+ char* const second = first + sizeof(uint16_t);
521
+ #ifdef __cpp_if_constexpr
522
+ if constexpr (detail::endian::native == detail::endian::little) {
523
+ #else
524
+ if (detail::endian::native == detail::endian::little) {
525
+ #endif
526
+ std::memset(first, 0, sizeof(uint16_t));
527
+ std::memcpy(second, &val, sizeof(uint16_t));
528
+ } else {
529
+ std::memcpy(first, &val, sizeof(uint16_t));
530
+ std::memset(second, 0, sizeof(uint16_t));
531
+ }
532
+ return result;
533
+ }
534
+
535
+ } // namespace onnxruntime_float16
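
A companion sketch for the bfloat16 path above: ToUint16Impl keeps the most significant half of the float and applies round-to-nearest-even via kRoundToNearest. MyBFloat16 below is hypothetical; Ort::BFloat16_t is the real counterpart.

// Hypothetical CRTP derivative, mirroring the MyFloat16 sketch earlier.
struct MyBFloat16 : onnxruntime_float16::BFloat16Impl<MyBFloat16> {
  static MyBFloat16 FromBits(uint16_t b) noexcept {
    MyBFloat16 r;
    r.val = b;
    return r;
  }
  static MyBFloat16 FromFloat(float v) noexcept { return FromBits(ToUint16Impl(v)); }
};

// 1.0f is 0x3F800000; its top half 0x3F80 survives unchanged.
// 1.00390625f is 0x3F808000, exactly halfway between 0x3F80 and 0x3F81;
// round-to-nearest-even sends it down to 0x3F80 (even mantissa).
// 1.01171875f is 0x3F818000, halfway with an odd lower neighbor (0x3F81);
// it rounds up to 0x3F82.
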
v1.23.1/headers/onnxruntime_lite_custom_op.h ADDED
@@ -0,0 +1,1119 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Summary
5
+ // The header has APIs to save custom op authors the trouble of defining schemas,
6
+ // which will be inferred from the function's signature, as long as the argument list uses types supported here.
7
+ // Input could be:
8
+ // 1. Tensor of onnx data types.
9
+ // 2. Span of onnx data types.
10
+ // 3. Scalar of onnx data types.
11
+ // An input can be marked optional by declaring it as std::optional<...>.
12
+ // An output must be a tensor of onnx data types.
13
+ // Further, the header has utilities for registering a simple custom struct (which may hold resources) as a custom op.
14
+ // For concrete examples, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
15
+ // Note - all APIs in this header are ABI-safe.
16
+
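
For illustration (a sketch, not part of the header): a hypothetical element-wise multiply whose schema follows entirely from its parameter list. CreateLiteCustomOp, defined later in this header, infers two float tensor inputs and one float tensor output from the signature.

#include "onnxruntime_lite_custom_op.h"

void MyMul(const Ort::Custom::Tensor<float>& a,
           const Ort::Custom::Tensor<float>& b,
           Ort::Custom::Tensor<float>& out) {
  const float* pa = a.Data();
  const float* pb = b.Data();
  float* po = out.Allocate(a.Shape());  // shape the output like the first input
  for (int64_t i = 0, n = a.NumberOfElement(); i < n; ++i) {
    po[i] = pa[i] * pb[i];
  }
}

// Registration sketch (CustomOpDomain comes from onnxruntime_cxx_api.h):
//   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> op{
//       Ort::Custom::CreateLiteCustomOp("MyMul", "CPUExecutionProvider", MyMul)};
//   Ort::CustomOpDomain domain{"my.domain"};
//   domain.Add(op.get());
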
17
+ #pragma once
18
+ #include "onnxruntime_cxx_api.h"
19
+ #include <optional>
20
+ #include <numeric>
21
+ #include <functional>
22
+ #include <unordered_set>
23
+
24
+ namespace Ort {
25
+ namespace Custom {
26
+
27
+ class ArgBase {
28
+ public:
29
+ ArgBase(OrtKernelContext* ctx,
30
+ size_t indice,
31
+ bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
32
+ virtual ~ArgBase() {};
33
+
34
+ protected:
35
+ struct KernelContext ctx_;
36
+ size_t indice_;
37
+ bool is_input_;
38
+ };
39
+
40
+ using ArgPtr = std::unique_ptr<Custom::ArgBase>;
41
+ using ArgPtrs = std::vector<ArgPtr>;
42
+
43
+ class TensorBase : public ArgBase {
44
+ public:
45
+ TensorBase(OrtKernelContext* ctx,
46
+ size_t indice,
47
+ bool is_input) : ArgBase(ctx, indice, is_input) {}
48
+
49
+ operator bool() const {
50
+ return shape_.has_value();
51
+ }
52
+
53
+ const std::vector<int64_t>& Shape() const {
54
+ if (!shape_.has_value()) {
55
+ ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
56
+ }
57
+ return shape_.value();
58
+ }
59
+
60
+ ONNXTensorElementDataType Type() const {
61
+ return type_;
62
+ }
63
+
64
+ int64_t NumberOfElement() const {
65
+ if (shape_.has_value()) {
66
+ return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
67
+ } else {
68
+ return 0;
69
+ }
70
+ }
71
+
72
+ std::string Shape2Str() const {
73
+ if (shape_.has_value()) {
74
+ std::string shape_str;
75
+ for (const auto& dim : *shape_) {
76
+ shape_str.append(std::to_string(dim));
77
+ shape_str.append(", ");
78
+ }
79
+ return shape_str;
80
+ } else {
81
+ return "empty";
82
+ }
83
+ }
84
+
85
+ bool IsCpuTensor() const {
86
+ return strcmp("Cpu", mem_type_) == 0;
87
+ }
88
+
89
+ virtual const void* DataRaw() const = 0;
90
+ virtual size_t SizeInBytes() const = 0;
91
+
92
+ protected:
93
+ std::optional<std::vector<int64_t>> shape_;
94
+ ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
95
+ const char* mem_type_ = "Cpu";
96
+ };
97
+
98
+ template <typename T>
99
+ struct Span {
100
+ const T* data_ = {};
101
+ size_t size_ = {};
102
+ void Assign(const T* data, size_t size) {
103
+ data_ = data;
104
+ size_ = size;
105
+ }
106
+ size_t size() const { return size_; }
107
+ T operator[](size_t indice) const {
108
+ return data_[indice];
109
+ }
110
+ const T* data() const { return data_; }
111
+ };
112
+
113
+ template <typename T>
114
+ class Tensor : public TensorBase {
115
+ public:
116
+ using TT = typename std::remove_reference<T>::type;
117
+ Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
118
+ if (is_input_) {
119
+ if (indice >= ctx_.GetInputCount()) {
120
+ ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
121
+ }
122
+ const_value_ = ctx_.GetInput(indice);
123
+ auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
124
+ shape_ = type_shape_info.GetShape();
125
+ }
126
+ }
127
+ const TT* Data() const {
128
+ return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
129
+ }
130
+ TT* Allocate(const std::vector<int64_t>& shape) {
131
+ shape_ = shape;
132
+ if (!data_) {
133
+ shape_ = shape;
134
+ data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
135
+ }
136
+ return data_;
137
+ }
138
+ static TT GetT() { return (TT)0; }
139
+ const Span<T>& AsSpan() {
140
+ if (!shape_.has_value() || shape_->size() != 1) {
141
+ ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
142
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
143
+ }
144
+ span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
145
+ return span_;
146
+ }
147
+ const T& AsScalar() {
148
+ if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
149
+ ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
150
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
151
+ }
152
+ return *Data();
153
+ }
154
+ const void* DataRaw() const override {
155
+ return reinterpret_cast<const void*>(Data());
156
+ }
157
+
158
+ size_t SizeInBytes() const override {
159
+ return sizeof(TT) * static_cast<size_t>(NumberOfElement());
160
+ }
161
+
162
+ private:
163
+ ConstValue const_value_; // for input
164
+ TT* data_{}; // for output
165
+ Span<T> span_;
166
+ };
167
+
168
+ template <>
169
+ class Tensor<std::string> : public TensorBase {
170
+ public:
171
+ using strings = std::vector<std::string>;
172
+
173
+ Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
174
+ if (is_input_) {
175
+ if (indice >= ctx_.GetInputCount()) {
176
+ ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
177
+ }
178
+ auto const_value = ctx_.GetInput(indice);
179
+ auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
180
+ shape_ = type_shape_info.GetShape();
181
+ auto num_chars = const_value.GetStringTensorDataLength();
182
+ // note - the string data is copied out of the tensor; the backward loop below null-terminates each entry in place before reading it.
183
+ auto num_strings = static_cast<size_t>(NumberOfElement());
184
+ if (num_strings) {
185
+ std::vector<char> chars(num_chars + 1, '\0');
186
+ std::vector<size_t> offsets(num_strings);
187
+ const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
188
+ auto upper_bound = num_strings - 1;
189
+ input_strings_.resize(num_strings);
190
+ for (size_t i = upper_bound;; --i) {
191
+ if (i < upper_bound) {
192
+ chars[offsets[i + 1]] = '\0';
193
+ }
194
+ input_strings_[i] = chars.data() + offsets[i];
195
+ if (0 == i) {
196
+ break;
197
+ }
198
+ }
199
+ }
200
+ }
201
+ }
202
+ const strings& Data() const {
203
+ return input_strings_;
204
+ }
205
+ const void* DataRaw() const override {
206
+ if (input_strings_.size() != 1) {
207
+ ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
208
+ }
209
+ return reinterpret_cast<const void*>(input_strings_[0].c_str());
210
+ }
211
+ size_t SizeInBytes() const override {
212
+ if (input_strings_.size() != 1) {
213
+ ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
214
+ }
215
+ return input_strings_[0].size();
216
+ }
217
+ void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
218
+ shape_ = dims;
219
+ std::vector<const char*> raw;
220
+ for (const auto& s : ss) {
221
+ raw.push_back(s.data());
222
+ }
223
+ auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
224
+ // note - FillStringTensor copies the strings into the output tensor.
225
+ output.FillStringTensor(raw.data(), raw.size());
226
+ }
227
+ const Span<std::string>& AsSpan() {
228
+ ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
229
+ }
230
+ const std::string& AsScalar() {
231
+ if (input_strings_.size() != 1) {
232
+ ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
233
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
234
+ }
235
+ return input_strings_[0];
236
+ }
237
+
238
+ private:
239
+ std::vector<std::string> input_strings_; // for input
240
+ };
241
+
242
+ template <>
243
+ class Tensor<std::string_view> : public TensorBase {
244
+ public:
245
+ using strings = std::vector<std::string>;
246
+ using string_views = std::vector<std::string_view>;
247
+
248
+ Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
249
+ if (is_input_) {
250
+ if (indice >= ctx_.GetInputCount()) {
251
+ ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
252
+ }
253
+ auto const_value = ctx_.GetInput(indice);
254
+ auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
255
+ shape_ = type_shape_info.GetShape();
256
+ auto num_chars = const_value.GetStringTensorDataLength();
257
+ chars_.resize(num_chars + 1, '\0');
258
+ auto num_strings = static_cast<size_t>(NumberOfElement());
259
+ if (num_strings) {
260
+ std::vector<size_t> offsets(num_strings);
261
+ const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
262
+ offsets.push_back(num_chars);
263
+ for (size_t i = 0; i < num_strings; ++i) {
264
+ input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
265
+ }
266
+ }
267
+ }
268
+ }
269
+ const string_views& Data() const {
270
+ return input_string_views_;
271
+ }
272
+ const void* DataRaw() const override {
273
+ if (input_string_views_.size() != 1) {
274
+ ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
275
+ }
276
+ return reinterpret_cast<const void*>(input_string_views_[0].data());
277
+ }
278
+ size_t SizeInBytes() const override {
279
+ if (input_string_views_.size() != 1) {
280
+ ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
281
+ }
282
+ return input_string_views_[0].size();
283
+ }
284
+ void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
285
+ shape_ = dims;
286
+ std::vector<const char*> raw;
287
+ for (const auto& s : ss) {
288
+ raw.push_back(s.data());
289
+ }
290
+ auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
291
+ // note - FillStringTensor copies the strings into the output tensor.
292
+ output.FillStringTensor(raw.data(), raw.size());
293
+ }
294
+ const Span<std::string_view>& AsSpan() {
295
+ ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
296
+ }
297
+ std::string_view AsScalar() {
298
+ if (input_string_views_.size() != 1) {
299
+ ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
300
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
301
+ }
302
+ return input_string_views_[0];
303
+ }
304
+
305
+ private:
306
+ std::vector<char> chars_; // for input
307
+ std::vector<std::string_view> input_string_views_; // for input
308
+ };
309
+
310
+ using TensorPtr = std::unique_ptr<Custom::TensorBase>;
311
+ using TensorPtrs = std::vector<TensorPtr>;
312
+
313
+ struct TensorArray : public ArgBase {
314
+ TensorArray(OrtKernelContext* ctx,
315
+ size_t start_indice,
316
+ bool is_input) : ArgBase(ctx,
317
+ start_indice,
318
+ is_input) {
319
+ if (is_input) {
320
+ auto input_count = ctx_.GetInputCount();
321
+ for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
322
+ auto const_value = ctx_.GetInput(ith_input);
323
+ auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
324
+ auto type = type_shape_info.GetElementType();
325
+ TensorPtr tensor;
326
+ switch (type) {
327
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
328
+ tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
329
+ break;
330
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
331
+ tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
332
+ break;
333
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
334
+ tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
335
+ break;
336
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
337
+ tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
338
+ break;
339
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
340
+ tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
341
+ break;
342
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
343
+ tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
344
+ break;
345
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
346
+ tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
347
+ break;
348
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
349
+ tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
350
+ break;
351
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
352
+ tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
353
+ break;
354
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
355
+ tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
356
+ break;
357
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
358
+ tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
359
+ break;
360
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
361
+ tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
362
+ break;
363
+ default:
364
+ ORT_CXX_API_THROW("unknown input type", ORT_RUNTIME_EXCEPTION);
365
+ break;
366
+ }
367
+ tensors_.emplace_back(tensor.release());
368
+ } // for
369
+ }
370
+ }
371
+ template <typename T>
372
+ T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
373
+ // ith_output is the index of the output relative to the tensor array;
374
+ // indice_ + ith_output is the index relative to the kernel context
375
+ auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
376
+ auto raw_output = tensor.get()->Allocate(shape);
377
+ tensors_.emplace_back(tensor.release());
378
+ return raw_output;
379
+ }
380
+ Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
381
+ // ith_output is the index of the output relative to the tensor array;
383
+ // indice_ + ith_output is the index relative to the kernel context
383
+ auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
384
+ Tensor<std::string>& output = *tensor;
385
+ tensors_.emplace_back(tensor.release());
386
+ return output;
387
+ }
388
+ size_t Size() const {
389
+ return tensors_.size();
390
+ }
391
+ const TensorPtr& operator[](size_t ith_input) const {
392
+ // ith_input is the index of the input relative to the tensor array
393
+ return tensors_.at(ith_input);
394
+ }
395
+
396
+ private:
397
+ TensorPtrs tensors_;
398
+ };
399
+
400
+ using Variadic = TensorArray;
401
+
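
A sketch of a variadic kernel built on the Variadic alias above (hypothetical, and assuming every input is a float tensor):

#include <cstring>

// Copies each float input tensor to an identically shaped output.
void MyPassThrough(const Ort::Custom::Variadic& inputs,
                   Ort::Custom::Variadic& outputs) {
  for (size_t i = 0; i < inputs.Size(); ++i) {
    const auto& in = inputs[i];  // a TensorPtr
    float* out = outputs.AllocateOutput<float>(i, in->Shape());
    std::memcpy(out, in->DataRaw(), in->SizeInBytes());
  }
}
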
402
+ /*
403
+ Note:
404
+ OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and the ort core.
405
+ The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
406
+ 1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release it, since there is no virtual destructor in the hierarchy.
407
+ 2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in the form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp,
408
+ hence memory can still be recycled properly.
409
+ Further, OrtCustomOp is a C struct bearing no v-table, so derived structs are by design required to have zero virtual functions to maintain cast safety.
410
+ */
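
A one-line usage sketch of the ownership rule above (hypothetical names, reusing the MyMul example from the header summary):

// Hold and release through OrtLiteCustomOp (or a sub-struct), never through a
// plain OrtCustomOp*, because the hierarchy has no virtual destructor:
//   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> op{
//       Ort::Custom::CreateLiteCustomOp("MyMul", "CPUExecutionProvider", MyMul)};
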
411
+ struct OrtLiteCustomOp : public OrtCustomOp {
412
+ using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
413
+ using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
414
+
415
+ // CreateTuple
416
+ template <size_t ith_input, size_t ith_output, typename... Ts>
417
+ static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
418
+ CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
419
+ return std::make_tuple();
420
+ }
421
+
422
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
423
+ static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
424
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
425
+ std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
426
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
427
+ return std::tuple_cat(current, next);
428
+ }
429
+
430
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
431
+ static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
432
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
433
+ std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
434
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
435
+ return std::tuple_cat(current, next);
436
+ }
437
+
438
+ #ifdef ORT_CUDA_CTX
439
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
440
+ static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
441
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
442
+ thread_local CudaContext cuda_context;
443
+ cuda_context.Init(*context);
444
+ std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
445
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
446
+ return std::tuple_cat(current, next);
447
+ }
448
+ #endif
449
+
450
+ #ifdef ORT_ROCM_CTX
451
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
452
+ static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
453
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
454
+ thread_local RocmContext rocm_context;
455
+ rocm_context.Init(*context);
456
+ std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
457
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
458
+ return std::tuple_cat(current, next);
459
+ }
460
+ #endif
461
+
462
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
463
+ static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
464
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
465
+ args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
466
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
467
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
468
+ return std::tuple_cat(current, next);
469
+ }
470
+
471
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
472
+ static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
473
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
474
+ args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
475
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
476
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
477
+ return std::tuple_cat(current, next);
478
+ }
479
+
480
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
481
+ static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
482
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
483
+ args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
484
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
485
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
486
+ return std::tuple_cat(current, next);
487
+ }
488
+
489
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
490
+ static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
491
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
492
+ args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
493
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
494
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
495
+ return std::tuple_cat(current, next);
496
+ }
497
+
498
+ #define CREATE_TUPLE_INPUT(data_type) \
499
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
500
+ static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
501
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
502
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
503
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
504
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
505
+ return std::tuple_cat(current, next); \
506
+ } \
507
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
508
+ static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
509
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
510
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
511
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
512
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
513
+ return std::tuple_cat(current, next); \
514
+ } \
515
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
516
+ static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
517
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
518
+ if (ith_input < num_input) { \
519
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
520
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
521
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
522
+ return std::tuple_cat(current, next); \
523
+ } else { \
524
+ std::tuple<T> current = std::tuple<T>{}; \
525
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
526
+ return std::tuple_cat(current, next); \
527
+ } \
528
+ } \
529
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
530
+ static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
531
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
532
+ if ("CPUExecutionProvider" != ep) { \
533
+ ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
534
+ } \
535
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
536
+ std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
537
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
538
+ return std::tuple_cat(current, next); \
539
+ } \
540
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
541
+ static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
542
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
543
+ if ("CPUExecutionProvider" != ep) { \
544
+ ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
545
+ } \
546
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
547
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
548
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
549
+ return std::tuple_cat(current, next); \
550
+ } \
551
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
552
+ static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
553
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
554
+ if (ith_input < num_input) { \
555
+ if ("CPUExecutionProvider" != ep) { \
556
+ ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
557
+ } \
558
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
559
+ std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
560
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
561
+ return std::tuple_cat(current, next); \
562
+ } else { \
563
+ std::tuple<T> current = std::tuple<T>{}; \
564
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
565
+ return std::tuple_cat(current, next); \
566
+ } \
567
+ } \
568
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
569
+ static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
570
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
571
+ if ("CPUExecutionProvider" != ep) { \
572
+ ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
573
+ } \
574
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
575
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
576
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
577
+ return std::tuple_cat(current, next); \
578
+ } \
579
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
580
+ static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
581
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
582
+ if (ith_input < num_input) { \
583
+ if ("CPUExecutionProvider" != ep) { \
584
+ ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
585
+ } \
586
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
587
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
588
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
589
+ return std::tuple_cat(current, next); \
590
+ } else { \
591
+ std::tuple<T> current = std::tuple<T>{}; \
592
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
593
+ return std::tuple_cat(current, next); \
594
+ } \
595
+ }
596
+ #define CREATE_TUPLE_OUTPUT(data_type) \
597
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
598
+ static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
599
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
600
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
601
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
602
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
603
+ return std::tuple_cat(current, next); \
604
+ } \
605
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
606
+ static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
607
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
608
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
609
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
610
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
611
+ return std::tuple_cat(current, next); \
612
+ } \
613
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
614
+ static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
615
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
616
+ if (ith_output < num_output) { \
617
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
618
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
619
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
620
+ return std::tuple_cat(current, next); \
621
+ } else { \
622
+ std::tuple<T> current = std::tuple<T>{}; \
623
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
624
+ return std::tuple_cat(current, next); \
625
+ } \
626
+ }
627
+ #define CREATE_TUPLE(data_type) \
628
+ CREATE_TUPLE_INPUT(data_type) \
629
+ CREATE_TUPLE_OUTPUT(data_type)
630
+
631
+ CREATE_TUPLE(bool)
632
+ CREATE_TUPLE(float)
633
+ CREATE_TUPLE(Ort::Float16_t)
634
+ CREATE_TUPLE(Ort::BFloat16_t)
635
+ CREATE_TUPLE(double)
636
+ CREATE_TUPLE(int8_t)
637
+ CREATE_TUPLE(int16_t)
638
+ CREATE_TUPLE(int32_t)
639
+ CREATE_TUPLE(int64_t)
640
+ CREATE_TUPLE(uint8_t)
641
+ CREATE_TUPLE(uint16_t)
642
+ CREATE_TUPLE(uint32_t)
643
+ CREATE_TUPLE(uint64_t)
644
+ CREATE_TUPLE(std::string)
645
+ CREATE_TUPLE_INPUT(std::string_view)
646
+ CREATE_TUPLE(Ort::Float8E4M3FN_t)
647
+ CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
648
+ CREATE_TUPLE(Ort::Float8E5M2_t)
649
+ CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
650
+
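
Reading aid for the CreateTuple overloads above: each specialization peels one parameter off the pack, materializes the matching argument in `args`, and concatenates it onto the tuple built for the remaining parameters. Conceptually (an illustrative expansion with casts omitted, not literal header code):

// For void F(const Ort::Custom::Tensor<float>& x, Ort::Custom::Tensor<float>& y),
// the recursion effectively performs:
//   args.push_back(std::make_unique<Tensor<float>>(ctx, /*ith_input=*/0, true));
//   auto t_in  = std::tuple<const Tensor<float>&>{*args.back()};
//   args.push_back(std::make_unique<Tensor<float>>(ctx, /*ith_output=*/0, false));
//   auto t_out = std::tuple<Tensor<float>&>{*args.back()};
//   return std::tuple_cat(t_in, t_out);  // later unpacked onto F via std::apply
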
651
+ // ParseArgs ...
652
+ template <typename... Ts>
653
+ static typename std::enable_if<0 == sizeof...(Ts)>::type
654
+ ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
655
+ }
656
+
657
+ template <typename T, typename... Ts>
658
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
659
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
660
+ ParseArgs<Ts...>(input_types, output_types);
661
+ }
662
+
663
+ template <typename T, typename... Ts>
664
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
665
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
666
+ ParseArgs<Ts...>(input_types, output_types);
667
+ }
668
+
669
+ #ifdef ORT_CUDA_CTX
670
+ template <typename T, typename... Ts>
671
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
672
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
673
+ ParseArgs<Ts...>(input_types, output_types);
674
+ }
675
+ #endif
676
+
677
+ #ifdef ORT_ROCM_CTX
678
+ template <typename T, typename... Ts>
679
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
680
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
681
+ ParseArgs<Ts...>(input_types, output_types);
682
+ }
683
+ #endif
684
+
685
+ template <typename T, typename... Ts>
686
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
687
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
688
+ input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
689
+ ParseArgs<Ts...>(input_types, output_types);
690
+ }
691
+
692
+ template <typename T, typename... Ts>
693
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
694
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
695
+ input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
696
+ ParseArgs<Ts...>(input_types, output_types);
697
+ }
698
+
699
+ template <typename T, typename... Ts>
700
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
701
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
702
+ output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
703
+ ParseArgs<Ts...>(input_types, output_types);
704
+ }
705
+
706
+ template <typename T, typename... Ts>
707
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
708
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
709
+ output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
710
+ ParseArgs<Ts...>(input_types, output_types);
711
+ }
712
+
713
+ #define PARSE_INPUT_BASE(pack_type, onnx_type) \
714
+ template <typename T, typename... Ts> \
715
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
716
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
717
+ input_types.push_back(onnx_type); \
718
+ ParseArgs<Ts...>(input_types, output_types); \
719
+ } \
720
+ template <typename T, typename... Ts> \
721
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
722
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
723
+ input_types.push_back(onnx_type); \
724
+ ParseArgs<Ts...>(input_types, output_types); \
725
+ } \
726
+ template <typename T, typename... Ts> \
727
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
728
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
729
+ input_types.push_back(onnx_type); \
730
+ ParseArgs<Ts...>(input_types, output_types); \
731
+ }
732
+
733
+ #define PARSE_INPUT(data_type, onnx_type) \
734
+ PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
735
+ PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
736
+ PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
737
+ PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
738
+ PARSE_INPUT_BASE(data_type, onnx_type)
739
+
740
+ #define PARSE_OUTPUT(data_type, onnx_type) \
741
+ template <typename T, typename... Ts> \
742
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
743
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
744
+ output_types.push_back(onnx_type); \
745
+ ParseArgs<Ts...>(input_types, output_types); \
746
+ } \
747
+ template <typename T, typename... Ts> \
748
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
749
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
750
+ output_types.push_back(onnx_type); \
751
+ ParseArgs<Ts...>(input_types, output_types); \
752
+ } \
753
+ template <typename T, typename... Ts> \
754
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
755
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
756
+ output_types.push_back(onnx_type); \
757
+ ParseArgs<Ts...>(input_types, output_types); \
758
+ }
759
+
760
+ #define PARSE_ARGS(data_type, onnx_type) \
761
+ PARSE_INPUT(data_type, onnx_type) \
762
+ PARSE_OUTPUT(data_type, onnx_type)
763
+
764
+ PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
765
+ PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
766
+ PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
767
+ PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
768
+ PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
769
+ PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
770
+ PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
771
+ PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
772
+ PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
773
+ PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
774
+ PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
775
+ PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
776
+ PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
777
+ PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
778
+ PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
779
+ PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
780
+ PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
781
+ PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
782
+ PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
783
+
784
+ OrtLiteCustomOp(const char* op_name,
785
+ const char* execution_provider,
786
+ ShapeInferFn shape_infer_fn,
787
+ int start_ver = 1,
788
+ int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
789
+ execution_provider_(execution_provider),
790
+ shape_infer_fn_(shape_infer_fn),
791
+ start_ver_(start_ver),
792
+ end_ver_(end_ver) {
793
+ OrtCustomOp::version = ORT_API_VERSION;
794
+
795
+ OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
796
+ OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->execution_provider_.c_str(); };
797
+ OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };
798
+
799
+ OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
800
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
801
+ return self->input_types_.size();
802
+ };
803
+
804
+ OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
805
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
806
+ return self->input_types_[indice];
807
+ };
808
+
809
+ OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
810
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
811
+ return self->output_types_.size();
812
+ };
813
+
814
+ OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
815
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
816
+ return self->output_types_[indice];
817
+ };
818
+
819
+ OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
820
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
821
+ return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
822
+ };
823
+
824
+ OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
825
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
826
+ return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
827
+ };
828
+
829
+ // Variadic inputs and outputs have no minimum arity, and their element types need not be homogeneous.
830
+ OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
846
+ OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
847
+ OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
848
+ OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
849
+
850
+ OrtCustomOp::CreateKernelV2 = {};
851
+ OrtCustomOp::KernelComputeV2 = {};
852
+ OrtCustomOp::KernelCompute = {};
853
+
854
+ OrtCustomOp::InferOutputShapeFn = {};
855
+
856
+ OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
857
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
858
+ return self->start_ver_;
859
+ };
860
+
861
+ OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
862
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
863
+ return self->end_ver_;
864
+ };
865
+
866
+ OrtCustomOp::GetMayInplace = {};
867
+ OrtCustomOp::ReleaseMayInplace = {};
868
+ OrtCustomOp::GetAliasMap = {};
869
+ OrtCustomOp::ReleaseAliasMap = {};
870
+ }
871
+
872
+ const std::string op_name_;
873
+ const std::string execution_provider_;
874
+
875
+ std::vector<ONNXTensorElementDataType> input_types_;
876
+ std::vector<ONNXTensorElementDataType> output_types_;
877
+
878
+ ShapeInferFn shape_infer_fn_ = {};
879
+
880
+ int start_ver_ = 1;
881
+ int end_ver_ = MAX_CUSTOM_OP_END_VER;
882
+
883
+ void* compute_fn_ = {};
884
+ void* compute_fn_return_status_ = {};
885
+ };
886
+
887
+ //////////////////////////// OrtLiteCustomFunc ////////////////////////////////
888
+ // This struct implements "function as op": a standalone compute function is wrapped as a custom op.
889
+ // E.g. a function might be defined as:
890
+ // void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
891
+ // It could be registered this way:
892
+ // Ort::CustomOpDomain v2_domain{"v2"};
893
+ // std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
894
+ // v2_domain.Add(fil_op_ptr.get());
895
+ // session_options.Add(v2_domain);
896
+ // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
897
+ template <typename... Args>
898
+ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
899
+ using ComputeFn = void (*)(Args...);
900
+ using ComputeFnReturnStatus = Status (*)(Args...);
901
+ using MyType = OrtLiteCustomFunc<Args...>;
902
+
903
+ struct Kernel {
904
+ size_t num_input_{};
905
+ size_t num_output_{};
906
+ ComputeFn compute_fn_{};
907
+ ComputeFnReturnStatus compute_fn_return_status_{};
908
+ std::string ep_{};
909
+ };
910
+
911
+ OrtLiteCustomFunc(const char* op_name,
912
+ const char* execution_provider,
913
+ ComputeFn compute_fn,
914
+ ShapeInferFn shape_infer_fn = {},
915
+ int start_ver = 1,
916
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
917
+ compute_fn_ = reinterpret_cast<void*>(compute_fn);
918
+ ParseArgs<Args...>(input_types_, output_types_);
919
+
920
+ OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
921
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
922
+ std::vector<ArgPtr> args;
923
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
924
+ std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
925
+ };
926
+
927
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
928
+ auto kernel = std::make_unique<Kernel>();
929
+ auto me = static_cast<const MyType*>(this_);
930
+ kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
931
+ Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
932
+ Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
933
+ auto self = static_cast<const OrtLiteCustomFunc*>(this_);
934
+ kernel->ep_ = self->execution_provider_;
935
+ return reinterpret_cast<void*>(kernel.release());
936
+ };
937
+
938
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) {
939
+ delete reinterpret_cast<Kernel*>(op_kernel);
940
+ };
941
+
942
+ if (shape_infer_fn_) {
943
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
944
+ auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
945
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
946
+ return shape_info_fn(ctx);
947
+ };
948
+ }
949
+ }
950
+
951
+ OrtLiteCustomFunc(const char* op_name,
952
+ const char* execution_provider,
953
+ ComputeFnReturnStatus compute_fn_return_status,
954
+ ShapeInferFn shape_infer_fn = {},
955
+ int start_ver = 1,
956
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
957
+ compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
958
+ ParseArgs<Args...>(input_types_, output_types_);
959
+
960
+ OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
961
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
962
+ std::vector<ArgPtr> args;
963
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
964
+ return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
965
+ };
966
+
967
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
968
+ auto kernel = std::make_unique<Kernel>();
969
+ auto me = static_cast<const MyType*>(this_);
970
+ kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
971
+ Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
972
+ Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
973
+ auto self = static_cast<const OrtLiteCustomFunc*>(this_);
974
+ kernel->ep_ = self->execution_provider_;
975
+ return reinterpret_cast<void*>(kernel.release());
976
+ };
977
+
978
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) {
979
+ delete reinterpret_cast<Kernel*>(op_kernel);
980
+ };
981
+
982
+ if (shape_infer_fn_) {
983
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
984
+ auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
985
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
986
+ return shape_info_fn(ctx);
987
+ };
988
+ }
989
+ }
990
+ }; // struct OrtLiteCustomFunc
991
+
992
+ /////////////////////////// OrtLiteCustomStruct ///////////////////////////
993
+ // This struct implements "struct as op": a user-defined struct with a Compute method is wrapped as a custom op.
994
+ // E.g. a struct might be defined as:
995
+ // struct Merge {
996
+ // Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
997
+ // void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
998
+ // std::string_view string_in,
999
+ // Ort::Custom::Tensor<std::string>* strings_out) {...}
1000
+ // bool reverse_ = false;
1001
+ // };
1002
+ // It could be registered this way:
1003
+ // Ort::CustomOpDomain v2_domain{"v2"};
1004
+ // std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
1005
+ // v2_domain.Add(mrg_op_ptr.get());
1006
+ // session_options.Add(v2_domain);
1007
+ // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
1008
+ template <typename CustomOp>
1009
+ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
1010
+ template <typename... Args>
1011
+ using CustomComputeFn = void (CustomOp::*)(Args...);
1012
+
1013
+ template <typename... Args>
1014
+ using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);
1015
+
1016
+ using MyType = OrtLiteCustomStruct<CustomOp>;
1017
+
1018
+ struct Kernel {
1019
+ size_t num_input_{};
1020
+ size_t num_output_{};
1021
+ std::unique_ptr<CustomOp> custom_op_;
1022
+ std::string ep_{};
1023
+ };
1024
+
1025
+ OrtLiteCustomStruct(const char* op_name,
1026
+ const char* execution_provider,
1027
+ int start_ver = 1,
1028
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
1029
+ SetCompute(&CustomOp::Compute);
1030
+
1031
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1032
+ auto kernel = std::make_unique<Kernel>();
1033
+ Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
1034
+ Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
1035
+ kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
1036
+ auto self = static_cast<const OrtLiteCustomStruct*>(this_);
1037
+ kernel->ep_ = self->execution_provider_;
1038
+ return reinterpret_cast<void*>(kernel.release());
1039
+ };
1040
+
1041
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1042
+ delete reinterpret_cast<Kernel*>(op_kernel);
1043
+ };
1044
+
1045
+ SetShapeInfer<CustomOp>(0);
1046
+ }
1047
+
1048
+ template <typename... Args>
1049
+ void SetCompute(CustomComputeFn<Args...>) {
1050
+ ParseArgs<Args...>(input_types_, output_types_);
1051
+ OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
1052
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1053
+ ArgPtrs args;
1054
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
1055
+ std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
1056
+ };
1057
+ }
1058
+
1059
+ template <typename... Args>
1060
+ void SetCompute(CustomComputeFnReturnStatus<Args...>) {
1061
+ ParseArgs<Args...>(input_types_, output_types_);
1062
+ OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
1063
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1064
+ ArgPtrs args;
1065
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
1066
+ return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
1067
+ };
1068
+ }
1069
+
1070
+ template <typename C>
1071
+ decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
1072
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
1073
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
1074
+ return C::InferOutputShape(ctx);
1075
+ };
1076
+ return {};
1077
+ }
1078
+
1079
+ template <typename C>
1080
+ void SetShapeInfer(...) {
1081
+ OrtCustomOp::InferOutputShapeFn = {};
1082
+ }
1083
+ }; // struct OrtLiteCustomStruct
1084
+
1085
+ /////////////////////////// CreateLiteCustomOp ////////////////////////////
1086
+
1087
+ template <typename... Args>
1088
+ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1089
+ const char* execution_provider,
1090
+ void (*custom_compute_fn)(Args...),
1091
+ Status (*shape_infer_fn)(ShapeInferContext&) = {},
1092
+ int start_ver = 1,
1093
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1094
+ using LiteOp = OrtLiteCustomFunc<Args...>;
1095
+ return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
1096
+ }
1097
+
1098
+ template <typename... Args>
1099
+ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1100
+ const char* execution_provider,
1101
+ Status (*custom_compute_fn_v2)(Args...),
1102
+ Status (*shape_infer_fn)(ShapeInferContext&) = {},
1103
+ int start_ver = 1,
1104
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1105
+ using LiteOp = OrtLiteCustomFunc<Args...>;
1106
+ return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
1107
+ }
1108
+
1109
+ template <typename CustomOp>
1110
+ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1111
+ const char* execution_provider,
1112
+ int start_ver = 1,
1113
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1114
+ using LiteOp = OrtLiteCustomStruct<CustomOp>;
1115
+ return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
1116
+ }
1117
+
1118
+ } // namespace Custom
1119
+ } // namespace Ort
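
For reference, a minimal end-to-end sketch of the "function as op" flow described in the comments above. The Filter signature, the CreateLiteCustomOp call, and the domain registration come straight from those comments; the filtering rule (keep positive values) and the model file name are illustrative assumptions, not part of the header.

#include <algorithm>
#include <memory>
#include <vector>

#include "onnxruntime_lite_custom_op.h"

// Keep only the positive elements of the input; the rule itself is hypothetical,
// the header only fixes the signature.
void Filter(const Ort::Custom::Tensor<float>& floats_in,
            Ort::Custom::Tensor<float>& floats_out) {
  const float* in = floats_in.Data();
  std::vector<float> kept;
  for (int64_t i = 0; i < floats_in.NumberOfElement(); ++i) {
    if (in[i] > 0.f) kept.push_back(in[i]);
  }
  float* out = floats_out.Allocate({static_cast<int64_t>(kept.size())});
  std::copy(kept.begin(), kept.end(), out);
}

int main() {
  Ort::Env env;
  Ort::CustomOpDomain v2_domain{"v2"};
  std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{
      Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
  v2_domain.Add(fil_op_ptr.get());

  Ort::SessionOptions session_options;
  session_options.Add(v2_domain);
  // "filter_model.onnx" is a placeholder; the model must contain a node of type
  // "Filter" in the custom domain "v2".
  Ort::Session session(env, ORT_TSTR("filter_model.onnx"), session_options);
  // ... build input tensors and call session.Run(...) as usual ...
  return 0;
}

Since ORT keeps raw pointers to the registered op and domain, fil_op_ptr and v2_domain must outlive the session.
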
v1.23.1/headers/onnxruntime_run_options_config_keys.h ADDED
@@ -0,0 +1,54 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ /*
7
+ * This file defines RunOptions Config Keys and format of the Config Values.
8
+ *
9
+ * The Naming Convention for a RunOptions Config Key,
10
+ * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
11
+ * Such as "ep.cuda.use_arena"
12
+ * The Config Key cannot be empty
13
+ * The maximum length of the Config Key is 128
14
+ *
15
+ * The string format of a RunOptions Config Value is defined individually for each Config.
16
+ * The maximum length of the Config Value is 1024
17
+ */
18
+
19
+ // Key for enabling shrinkage of user-listed device memory arenas.
20
+ // Expects a semicolon-separated list of device:device_id pairs in the following format:
21
+ // "device_0:device_id_0;device_1:device_id_1"
22
+ // No whitespace is allowed in the list string.
23
+ // Currently, the only supported devices are "cpu" and "gpu" (case sensitive).
24
+ // If "cpu" is included in the list, the DisableCpuMemArena() API must not be called (i.e., the arena for cpu must be enabled).
25
+ // Example usage: "cpu:0;gpu:0" or "gpu:0"
26
+ // By default, the value for this key is empty (i.e., no memory arenas are shrunk).
27
+ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
28
+
29
+ // Set to '1' to skip synchronizing execution providers with the CPU at the end of a session run.
30
+ // By default it is set to '0'.
31
+ // Taking the CUDA EP as an example, this omits triggering cudaStreamSynchronize on the compute stream.
32
+ static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
33
+
34
+ // Set the HTP performance mode for the QNN HTP backend before session run.
35
+ // Options for HTP performance mode: "burst", "balanced", "default", "high_performance",
36
+ // "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
37
+ // "sustained_high_performance". Defaults to "default".
38
+ static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
39
+
40
+ // Set the HTP performance mode for the QNN HTP backend after session run.
41
+ static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
42
+
43
+ // Set RPC control latency for QNN HTP backend
44
+ static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
45
+
46
+ // Set the QNN LoRA config file used to apply LoRA within a QNN context binary.
47
+ static const char* const kOrtRunOptionsConfigQnnLoraConfig = "qnn.lora_config";
48
+
49
+ // Set the graph annotation id for the CUDA EP. Use with enable_cuda_graph=true.
50
+ // The value should be an integer. If the value is not set, the default value is 0 and
51
+ // the ORT session captures only one CUDA graph before another capture is requested.
52
+ // If the value is set to -1, CUDA graph capture/replay is disabled in that run.
53
+ // Users are not expected to set the value to 0, as it is reserved for internal use.
54
+ static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
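
As a usage sketch for the keys above (assuming the C++ API from onnxruntime_cxx_api.h, which wraps the AddRunConfigEntry C API), a config entry is attached to an Ort::RunOptions object that is then passed to Session::Run. The "cpu:0" value follows the shrinkage format documented for kOrtRunOptionsConfigEnableMemoryArenaShrinkage; the helper name is ours.

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_run_options_config_keys.h"

// Request shrinkage of the CPU arena (device id 0) at the end of the run.
// Valid only if the CPU arena is enabled, i.e. DisableCpuMemArena() was not called.
Ort::RunOptions MakeRunOptionsWithArenaShrinkage() {
  Ort::RunOptions run_options;
  run_options.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "cpu:0");
  return run_options;
}

The returned object is passed as the first argument of Ort::Session::Run, so the shrinkage applies to that run only.
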
v1.23.1/headers/onnxruntime_session_options_config_keys.h ADDED
@@ -0,0 +1,417 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ /*
7
+ * This file defines SessionOptions Config Keys and format of the Config Values.
8
+ *
9
+ * The Naming Convention for a SessionOptions Config Key,
10
+ * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
11
+ * Such as "ep.cuda.use_arena"
12
+ * The Config Key cannot be empty
13
+ * The maximum length of the Config Key is 1024
14
+ *
15
+ * The string format of a SessionOptions Config Value is defined individually for each Config.
16
+ * The maximum length of the Config Value is 2048
17
+ */
18
+
19
+ // Key for disable PrePacking,
20
+ // If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
21
+ static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
22
+
23
+ // A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
24
+ // will be used. Use this to override the usage of env allocators on a per session level.
25
+ static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
26
+
27
+ // Set to 'ORT' (case sensitive) to load an ORT format model.
28
+ // If unset, the model type will default to ONNX unless it is inferred to be ORT from the filename ('.ort' == ORT format) or from the bytes.
29
+ static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
30
+
31
+ // Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
32
+ // If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
33
+ static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";
34
+
35
+ // If the value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
36
+ // When multiple sessions are created, the main thread doesn't override changes from succeeding session options,
37
+ // but threads in session thread pools do follow option changes.
38
+ // When ORT runs with OpenMP, the same rule applies, i.e. the first session option to set flush-to-zero and
39
+ // denormal-as-zero is applied only to the global OpenMP thread pool, as there is no per-session thread pool.
40
+ // Note that an alternative that avoids this runtime option is to train and export the model without denormals;
41
+ // that is recommended, because turning this option on may hurt model accuracy.
42
+ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";
43
+
44
+ // It controls whether to run a quantized model in QDQ (QuantizeLinear/DequantizeLinear) format or not.
45
+ // "0": enable. ORT applies fusion logic for the QDQ format.
46
+ // "1": disable. ORT doesn't apply fusion logic for the QDQ format.
47
+ // Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1".
48
+ static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
49
+
50
+ // It controls whether to enable the Double QDQ remover and Identical Children Consolidation.
51
+ // "0": don't disable. ORT removes the middle two nodes from Q->(DQ->Q)->DQ pairs.
52
+ // "1": disable. ORT doesn't remove the middle two nodes from Q->(DQ->Q)->DQ pairs.
53
+ // Its default value is "0"
54
+ static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
55
+
56
+ // If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
57
+ // completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
58
+ // Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
59
+ // 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
60
+ // other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
61
+ // As such, it's best to test to determine if enabling this works well for your scenario.
62
+ // The default value is "0"
63
+ // Available since version 1.11.
64
+ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
65
+
66
+ // Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
67
+ // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
68
+ static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
69
+
70
+ // Enable or disable Cast chain elimination in graph optimization. "0": disable; "1": enable. The default is "0".
71
+ // CastElimination with chain elimination has side effects which may change the inference results. It is disabled by default due to this.
72
+ static const char* const kOrtSessionOptionsEnableCastChainElimination = "optimization.enable_cast_chain_elimination";
73
+
74
+ // This setting controls whether to enable AheadOfTime function inlining.
75
+ // AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
76
+ // as possible with the help of enabled execution providers.
77
+ // This can reduce the number of function calls and improve performance because it is done before
78
+ // Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
79
+ // one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
80
+ // "0": enable; "1": disable.
81
+ // Its default value is "0".
82
+ static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
83
+
84
+ #ifdef ENABLE_TRAINING
85
+ // Specifies a path of the file containing a list of memory optimization configurations.
86
+ // The value should be a string indicating the file path of the config file.
87
+ // The content of the config file is a JSON struct like this:
88
+ // [
89
+ // "Gelu+Cast+:1:0",
90
+ // "Dropout+:1:1"
91
+ // ]
92
+ // Taking the example of "Gelu+Cast+:1:0",
93
+ // > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation
94
+ // output by ORT graph transformations.
95
+ // > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute.
96
+ // > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization,
97
+ // to avoid "oversaving" the memory.
98
+ static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config";
99
+
100
+ // Specifies the config for detecting subgraphs for memory footprint reduction.
101
+ // The value should be a string containing integers separated by commas. The default value is "0:0".
102
+ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config";
103
+ #endif
104
+
105
+ // This setting, if set, should contain a comma-separated list of optimizer names that should be disabled.
106
+ // Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer
107
+ // does not provide runtime benefits but affects your model loading time, you may disable it using this config
108
+ // entry. This option is not enabled in an ORT_MINIMAL_BUILD build.
109
+ // A list of optimizers is available in onnxruntime/core/optimizer/graph_transformer_utils.cc
110
+ //
111
+ // Default is an empty string which means no optimizers are disabled.
112
+ static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
113
+
114
+ // It controls whether to run graph optimizations in a loop or not.
115
+ //
116
+ // "0": disable. Graph Optimization Loop is disabled.
117
+ // ```
118
+ // Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
119
+ // ^ |
120
+ // | "No Loop" |
121
+ // | |
122
+ // X xxxxxxxxxxx X
123
+ // ```
124
+ // "1": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 4 are applied then
125
+ // the loop will check for any other valid optimization that can happen.
126
+ // ```
127
+ // Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
128
+ // ^ |
129
+ // | "Loop only depending on Level 4" |
130
+ // | |
131
+ // ---------------------------------------------------
132
+ // ```
133
+ // "2": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 2 or above are applied then
134
+ // the loop will check for any other valid optimization that can happen.
135
+ // ```
136
+ // Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
137
+ // ^ |
138
+ // | "Loop" |
139
+ // | |
140
+ // ---------------------------------------------------
141
+ // ```
142
+ // Default value is set to "1".
143
+ static const char* const kOrtSessionOptionsGraphOptimizationsLoopLevel = "session.graph_optimizations_loop_level";
144
+
145
+ // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
146
+ // Using device allocators means the memory allocation is made using malloc/new.
147
+ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
148
+
149
+ // Configure whether to allow the inter_op/intra_op threads to spin a number of times before blocking.
150
+ // "0": thread will block if found no job to run
151
+ // "1": thread will spin a number of times before blocking
152
+ // The default is "0" when ORT is built with "ORT_CLIENT_PACKAGE_BUILD" and "1" otherwise.
153
+ // Thread spinning is disabled by default for client/on-device workloads to reduce cpu utilization and improve power efficiency.
154
+ static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
155
+ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
156
+
157
+ // Key for using model bytes directly for ORT format.
158
+ // If a session is created using an input byte array that contains the ORT format model data,
159
+ // by default we will copy the model bytes at the time of session creation to ensure the model bytes
160
+ // buffer is valid.
161
+ // Setting this option to "1" disables copying the model bytes and uses them directly. The caller
162
+ // has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
163
+ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
164
+
165
+ /// <summary>
166
+ /// Key for using the ORT format model flatbuffer bytes directly for initializers.
167
+ /// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
168
+ /// Requires `session.use_ort_model_bytes_directly` to be true.
169
+ /// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
170
+ /// duration of the InferenceSession.
171
+ /// </summary>
172
+ static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
173
+ "session.use_ort_model_bytes_for_initializers";
174
+
175
+ // This should only be specified when exporting an ORT format model for use on a different platform.
176
+ // If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
177
+ // Available since version 1.11.
178
+ static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
179
+
180
+ // x64 SSE4.1/AVX2/AVX512 (with no VNNI) has an overflow problem with quantized matrix multiplication with U8S8.
181
+ // To avoid this we need to use the slower U8U8 matrix multiplication instead. This option, if
182
+ // turned on, uses the slower U8U8 matrix multiplications. Only effective on AVX2 or AVX512
183
+ // platforms.
184
+ static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
185
+
186
+ // Specifies how minimal build graph optimizations are handled in a full build.
187
+ // These optimizations are at the extended level or higher.
188
+ // Possible values and their effects are:
189
+ // "save": Save runtime optimizations when saving an ORT format model.
190
+ // "apply": Only apply optimizations available in a minimal build.
191
+ // ""/<unspecified>: Apply optimizations available in a full build.
192
+ // Available since version 1.11.
193
+ static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
194
+ "optimization.minimal_build_optimizations";
195
+
196
+ // Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
197
+ // order for them to take effect.
198
+
199
+ // Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
200
+ // run by the NNAPI EP.
201
+ // The value should be a ","-delimited list of op types. For example, "Add,Sub".
202
+ // If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
203
+ // exclusion, set the value to "".
204
+ static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
205
+
206
+ // Enables dynamic block-sizing for multithreading.
207
+ // With a positive value, the thread pool will split a task of N iterations into blocks with size starting from:
208
+ // N / (num_of_threads * dynamic_block_base)
209
+ // As execution progresses, the size will decrease according to the diminishing residual of N,
210
+ // meaning the task will be distributed in smaller granularity for better parallelism.
211
+ // For some models, it helps to reduce the variance of E2E inference latency and boost performance.
212
+ // The feature is off by default; specify any positive integer, e.g. "4", to enable it.
213
+ // Available since version 1.11.
214
+ static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
215
+
216
+ // This option makes it possible to decrease CPU usage between infrequent
217
+ // requests: it forces any spinning thread-pool threads to stop immediately when the last
218
+ // concurrent Run() call returns.
219
+ // Spinning is restarted on the next Run() call.
220
+ // Applies only to internal thread pools.
221
+ static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
222
+
223
+ // "1": all inconsistencies encountered during shape and type inference
224
+ // will result in failures.
225
+ // "0": in some cases warnings will be logged but processing will continue. The default.
226
+ // May be useful to expose bugs in models.
227
+ static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
228
+
229
+ // "1": every model using a more recent opset than the latest released one will fail
230
+ // "0": the model may or may not work if onnxruntime cannot find an implementation, this option
231
+ // is used for development purpose.
232
+ static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only";
233
+
234
+ // The file saves the configuration for partitioning nodes among logic streams.
235
+ static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
236
+
237
+ // This option allows setting affinities for intra-op threads.
238
+ // The affinity string follows the format:
239
+ // logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
240
+ // Semicolons separate per-thread configurations, while commas separate the processors the i-th thread is expected to attach to.
241
+ // e.g. 1,2,3;4,5
242
+ // specifies affinities for two threads, with the 1st thread attached to the 1st, 2nd, and 3rd processors, and the 2nd thread to the 4th and 5th.
243
+ // To ease the configuration, an "interval" is also allowed:
244
+ // e.g. 1-8;8-16;17-24
245
+ // specifies that the 1st thread runs on the first eight processors, the 2nd thread runs on the next eight processors, and so forth.
246
+ // Note:
247
+ // 1. Once set, the number of thread affinities must equal intra_op_num_threads - 1; ort does not set an affinity on the main thread, which
248
+ // is started and managed by the calling app;
249
+ // 2. For Windows, ort will infer the group id from a logical processor id; for example, assuming there are two groups, each with 64 logical processors,
250
+ // an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
251
+ // Hence 64-65 is an invalid configuration, because a Windows thread cannot be attached to processors across a group boundary.
252
+ static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
253
+
254
+ // This option will dump out the model to assist debugging any issues with layout transformation,
255
+ // and is primarily intended for developer usage. It is only relevant if an execution provider that requests
256
+ // NHWC layout is enabled such as NNAPI, XNNPACK or QNN.
257
+ //
258
+ // Default is off. Set to "1" to enable.
259
+ //
260
+ // If modified by layout transformation, the model will be dumped after these steps:
261
+ // 1) insertion of the layout transformation Transpose nodes,
262
+ // 2) optimization of those Transpose nodes using the transpose optimizer,
263
+ // 3) application of the L1 transformers to the updated graph.
264
+ // The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
265
+ static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
266
+
267
+ // Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
268
+ // assigned (i.e., "fallback") to the CPU EP by default.
269
+ //
270
+ // This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
271
+ // If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
272
+ // fully support all of the nodes in the graph.
273
+ //
274
+ // It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
275
+ // will also fail with an error.
276
+ //
277
+ // Option values:
278
+ // - "0": CPU EP fallback is not disabled. [DEFAULT]
279
+ // - "1": CPU EP fallback is disabled.
280
+ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";
281
+
282
+ // Use this config when serializing a large model after optimization to specify an external initializers file
283
+ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
284
+ "session.optimized_model_external_initializers_file_name";
285
+
286
+ // Use this config to control the minimum size of the initializer when externalizing it during serialization
287
+ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
288
+ "session.optimized_model_external_initializers_min_size_in_bytes";
289
+
290
+ // When loading a model from a memory buffer and the model has external initializers,
291
+ // use this config to set the external data file folder path.
292
+ // All external data files should be in the same folder.
293
+ static const char* const kOrtSessionOptionsModelExternalInitializersFileFolderPath =
294
+ "session.model_external_initializers_file_folder_path";
295
+
296
+ // Use this config when saving pre-packed constant initializers to an external data file.
297
+ // This allows you to memory-map pre-packed initializers on model load and leave it
298
+ // to the OS to decide how much memory is consumed by the pre-packed initializers. Otherwise,
299
+ // pre-packed data resides on the heap.
300
+ //
301
+ // - "0": Do not save pre-packed initializers to a data file. [DEFAULT]
302
+ // - "1": Save pre-packed constant initializers to an external data file.
303
+ // Sample usage: sess_options.add_session_config_entry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1")
304
+ static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers =
305
+ "session.save_external_prepacked_constant_initializers";
306
+
307
+ // Use this config when you want to collect memory stats for each node in the graph.
309
+ // The file format is CSV with the following columns:
310
+ // node_name, initializers_memory, dynamic_outputs_sizes, temp_allocations_size
311
+ // The file will be created if it does not exist, and will be overwritten if it does.
312
+ //
313
+ // The content of the file can be used to estimate memory requirements at run time including
314
+ // the temporary allocations. This operation is preferably done on a CPU device, as the model may exceed
315
+ // device memory limits in constrained environments. When enabling this option, it is important to disable
316
+ // memory patterns, as they tend to allocate large blocks to avoid fragmentation and accommodate the needs of multiple
317
+ // kernels. Memory patterns may make it difficult to allocate on a device with limited memory.
318
+ //
319
+ // The collected stats can then be used to partition the graph among the devices in a way that only the
320
+ // required memory is allocated on each device.
321
+ //
322
+ // - "full path to file": there is no default for this option. If the file cannot be opened for writing, an error will be returned.
323
+ static const char* const kOrtSessionOptionsCollectNodeMemoryStatsToFile = "session.collect_node_memory_stats_to_file";
324
+
325
+ /// This is a composite CSV setting formatted as "memory limit in kb,file name for collected stats"
326
+ /// "limit > 0": enables Capacity Aware Partitioning for Cuda EP. `limit` is optional and when absent
327
+ /// the provider may attempt to figure out the memory available automatically.
328
+ /// The setting with no limit is expected to look like: ",file name for collected stats"
329
+ /// The EP will place nodes on the device according to "file name":
330
+ /// this file is expected to be found in the same folder as the model. The file contains
331
+ /// pre-recorded stats collected when running with kOrtSessionOptionsCollectNodeMemoryStatsToFile enabled (see above).
332
+ static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings =
333
+ "session.resource_cuda_partitioning_settings";
334
+
335
+ // Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
336
+ // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
337
+ // "0": disable. (default)
338
+ // "1": enable.
339
+ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
340
+
341
+ // Specify the file path for the Onnx model which has EP context.
342
+ // Defaults to original_file_name_ctx.onnx if not specified.
343
+ // A folder is not a valid option.
344
+ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
345
+
346
+ // Flag to specify whether to dump the EP context into the Onnx model.
347
+ // "0": dump the EP context into separate file, keep the file name in the Onnx model. (default).
348
+ // "1": dump the EP context into the Onnx model.
349
+ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
350
+
351
+ // Specify the EPContext node name prefix to make it unique
352
+ // in case the user needs to merge/connect multiple EPContext nodes in one model.
353
+ static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix";
354
+
355
+ // Share EP related resources across sessions
356
+ static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts";
357
+
358
+ // Stop sharing EP-related resources across sessions from then on.
359
+ static const char* const kOrtSessionOptionStopShareEpContexts = "ep.stop_share_ep_contexts";
360
+
361
+ // Used only for context model generation.
362
+ // This configuration is used when some nodes are partitioned on the CPU EP and those nodes have external initializers.
363
+ // When generating the EP context model, the new model should not rely on the old external data file used by the source ONNX model.
364
+ // Use this setting when dumping the EP context model with an external initializers file.
365
+ // If specified, all initializers will be placed inside the external data file.
366
+ // Otherwise, all initializers will be embedded inside the generated ONNX file.
367
+ // By default, this option is not set, meaning all initializers will be included within the ONNX file.
368
+ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFileName =
369
+ "ep.context_model_external_initializers_file_name";
370
+
371
+ // Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
372
+ // Option values:
373
+ // - "0": Gemm FastMath mode is not enabled. [DEFAULT]
374
+ // - "1": Gemm FastMath mode is enabled.
375
+ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
376
+
377
+ // When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
378
+ // Refer to MatMulNBits op schema for more details.
379
+ // If not provided, default is 4.
380
+ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
381
+
382
+ // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
383
+ // Meant to be used with SetEpDynamicOptions
384
+ // Specify the type of workload for this session.
385
+ // "Default": OS determines the scheduling priority and processor performance to service this workload. [Default]
386
+ // "Efficient": OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
387
+ static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
388
+
389
+ // Disables model compilation during session initialization.
390
+ //
391
+ // If this option is set to "1", inference session creation will fail with error code ORT_MODEL_REQUIRES_COMPILATION
392
+ // if compilation is required to run the model on any Execution Provider added to the session.
393
+ // Only the following kinds of models are valid when this option is set to "1":
394
+ // - Pre-compiled models that have EPContext nodes for the compiling Execution Providers in the session.
395
+ // - Non-compiled models that run only on non-compiling Execution Providers, like CPU EP.
396
+ //
397
+ // See https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for details about
398
+ // compiled models with EPContext nodes.
399
+ //
400
+ // Option values:
401
+ // - "0": EP compile is not disabled. [DEFAULT]
402
+ // - "1": EP compile is disabled.
403
+ static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile";
404
+
405
+ // Controls behavior when compiled model compatibility is SUPPORTED_PREFER_RECOMPILATION.
406
+ // "0": Allow execution with suboptimal performance. [DEFAULT]
407
+ // "1": Fail session creation to require recompilation for optimal performance.
408
+ // Note: UNSUPPORTED models always fail regardless of this setting.
409
+ static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel =
410
+ "session.fail_on_suboptimal_compiled_model";
411
+
412
+ // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
413
+ // Meant to be used with SetEpDynamicOptions
414
+ // options for HTP performance mode: "burst", "balanced", "default", "high_performance",
415
+ // "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
416
+ // "sustained_high_performance". Default to "default".
417
+ static const char* const kOrtEpDynamicOptionsQnnHtpPerformanceMode = "ep.dynamic.qnn_htp_performance_mode";
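
A short sketch of how these keys are consumed (assuming the C++ API): each key is passed to Ort::SessionOptions::AddConfigEntry before the session is created. The thread-affinity string and thread count follow the format documented for kOrtSessionOptionsConfigIntraOpThreadAffinities above; the specific processor ids and the helper name are illustrative.

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

Ort::SessionOptions MakeConfiguredSessionOptions() {
  Ort::SessionOptions so;
  // Fail session creation instead of silently assigning unsupported nodes to the
  // CPU EP (meaningful only when a non-CPU EP is also appended to `so`).
  so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");
  // Pin two worker threads (1st to processors 1,2,3; 2nd to 4,5). The number of
  // affinity entries must equal intra_op_num_threads - 1, hence three threads total.
  so.SetIntraOpNumThreads(3);
  so.AddConfigEntry(kOrtSessionOptionsConfigIntraOpThreadAffinities, "1,2,3;4,5");
  return so;
}

AddConfigEntry returns the SessionOptions reference, so the calls can also be chained.
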
v1.23.1/jni/arm64-v8a/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:629d677b15742620b008fdeafafcf825e6635de5e650f2728b98cc7c659fbbe5
3
+ size 19343456
v1.23.1/jni/arm64-v8a/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ebb4db00b87e243a6e7f401e30cac6b4172a921a59ba44bb72a53eba55e6920
3
+ size 100648
v1.23.1/jni/armeabi-v7a/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c6a7a9d851346abaf9911e6fa1bbb4433ddeb0f5c9d17429cf59942707613bb
3
+ size 13988872
v1.23.1/jni/armeabi-v7a/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0ce9f1da045c17666540e61d0bd0861d9a6fe9044486ea78c1c34d6a8f9ecc9
3
+ size 73680
v1.23.1/jni/x86/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a78e94433c3abb77a9c7ddf2b8b74e94f7af4e6073b3288a32d6d98a68c38ed
3
+ size 22757348
v1.23.1/jni/x86/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55bd2fc2473495913e3be73b706b8e0bd4138b561c80cacda52e487b2f47e30e
3
+ size 84700
v1.23.1/jni/x86_64/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b830967d557bd5ecf90de8357ffac3d6ec8a27f70d48d2e954e1809b0612165
3
+ size 23176928
v1.23.1/jni/x86_64/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfd2404d809a7339bd4ab7719af93c5786d003aff2ea1a1f312b6bc3b1d64366
3
+ size 90728