Commit
·
11ce2d8
1
Parent(s):
bc4ec38
add 1.23.1
Browse files- v1.23.1/headers/cpu_provider_factory.h +19 -0
- v1.23.1/headers/nnapi_provider_factory.h +62 -0
- v1.23.1/headers/onnxruntime_c_api.h +0 -0
- v1.23.1/headers/onnxruntime_cxx_api.h +0 -0
- v1.23.1/headers/onnxruntime_cxx_inline.h +0 -0
- v1.23.1/headers/onnxruntime_ep_c_api.h +988 -0
- v1.23.1/headers/onnxruntime_ep_device_ep_metadata_keys.h +18 -0
- v1.23.1/headers/onnxruntime_float16.h +535 -0
- v1.23.1/headers/onnxruntime_lite_custom_op.h +1119 -0
- v1.23.1/headers/onnxruntime_run_options_config_keys.h +54 -0
- v1.23.1/headers/onnxruntime_session_options_config_keys.h +417 -0
- v1.23.1/jni/arm64-v8a/libonnxruntime.so +3 -0
- v1.23.1/jni/arm64-v8a/libonnxruntime4j_jni.so +3 -0
- v1.23.1/jni/armeabi-v7a/libonnxruntime.so +3 -0
- v1.23.1/jni/armeabi-v7a/libonnxruntime4j_jni.so +3 -0
- v1.23.1/jni/x86/libonnxruntime.so +3 -0
- v1.23.1/jni/x86/libonnxruntime4j_jni.so +3 -0
- v1.23.1/jni/x86_64/libonnxruntime.so +3 -0
- v1.23.1/jni/x86_64/libonnxruntime4j_jni.so +3 -0
v1.23.1/headers/cpu_provider_factory.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#include "onnxruntime_c_api.h"
|
| 5 |
+
|
| 6 |
+
#ifdef __cplusplus
|
| 7 |
+
extern "C" {
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
/**
|
| 11 |
+
* \param use_arena zero: false. non-zero: true.
|
| 12 |
+
*/
|
| 13 |
+
ORT_EXPORT
|
| 14 |
+
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
|
| 15 |
+
ORT_ALL_ARGS_NONNULL;
|
| 16 |
+
|
| 17 |
+
#ifdef __cplusplus
|
| 18 |
+
}
|
| 19 |
+
#endif
|
v1.23.1/headers/nnapi_provider_factory.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include "onnxruntime_c_api.h"
|
| 6 |
+
|
| 7 |
+
// NNAPIFlags are bool options we want to set for NNAPI EP
|
| 8 |
+
// This enum is defined as bit flags, and cannot have negative value
|
| 9 |
+
// To generate an uint32_t nnapi_flags for using with OrtSessionOptionsAppendExecutionProvider_Nnapi below,
|
| 10 |
+
// uint32_t nnapi_flags = 0;
|
| 11 |
+
// nnapi_flags |= NNAPI_FLAG_USE_FP16;
|
| 12 |
+
enum NNAPIFlags {
|
| 13 |
+
NNAPI_FLAG_USE_NONE = 0x000,
|
| 14 |
+
|
| 15 |
+
// Using fp16 relaxation in NNAPI EP, this may improve perf but may also reduce precision
|
| 16 |
+
NNAPI_FLAG_USE_FP16 = 0x001,
|
| 17 |
+
|
| 18 |
+
// Use NCHW layout in NNAPI EP, this is only available after Android API level 29
|
| 19 |
+
// Please note for now, NNAPI perform worse using NCHW compare to using NHWC
|
| 20 |
+
NNAPI_FLAG_USE_NCHW = 0x002,
|
| 21 |
+
|
| 22 |
+
// Prevent NNAPI from using CPU devices.
|
| 23 |
+
//
|
| 24 |
+
// NNAPI is more efficient using GPU or NPU for execution, and NNAPI might fall back to its own CPU implementation
|
| 25 |
+
// for operations not supported by GPU/NPU. The CPU implementation of NNAPI (which is called nnapi-reference)
|
| 26 |
+
// might be less efficient than the optimized versions of the operation of ORT. It might be advantageous to disable
|
| 27 |
+
// the NNAPI CPU fallback and handle execution using ORT kernels.
|
| 28 |
+
//
|
| 29 |
+
// For some models, if NNAPI would use CPU to execute an operation, and this flag is set, the execution of the
|
| 30 |
+
// model may fall back to ORT kernels.
|
| 31 |
+
//
|
| 32 |
+
// This option is only available after Android API level 29, and will be ignored for Android API level 28-
|
| 33 |
+
//
|
| 34 |
+
// For NNAPI device assignments, see https://developer.android.com/ndk/guides/neuralnetworks#device-assignment
|
| 35 |
+
// For NNAPI CPU fallback, see https://developer.android.com/ndk/guides/neuralnetworks#cpu-fallback
|
| 36 |
+
//
|
| 37 |
+
// Please note, the NNAPI EP will return error status if both NNAPI_FLAG_CPU_DISABLED
|
| 38 |
+
// and NNAPI_FLAG_CPU_ONLY flags are set
|
| 39 |
+
NNAPI_FLAG_CPU_DISABLED = 0x004,
|
| 40 |
+
|
| 41 |
+
// Using CPU only in NNAPI EP, this may decrease the perf but will provide
|
| 42 |
+
// reference output value without precision loss, which is useful for validation
|
| 43 |
+
//
|
| 44 |
+
// Please note, the NNAPI EP will return error status if both NNAPI_FLAG_CPU_DISABLED
|
| 45 |
+
// and NNAPI_FLAG_CPU_ONLY flags are set
|
| 46 |
+
NNAPI_FLAG_CPU_ONLY = 0x008,
|
| 47 |
+
|
| 48 |
+
// Keep NNAPI_FLAG_LAST at the end of the enum definition
|
| 49 |
+
// And assign the last NNAPIFlag to it
|
| 50 |
+
NNAPI_FLAG_LAST = NNAPI_FLAG_CPU_ONLY,
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
#ifdef __cplusplus
|
| 54 |
+
extern "C" {
|
| 55 |
+
#endif
|
| 56 |
+
|
| 57 |
+
ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi,
|
| 58 |
+
_In_ OrtSessionOptions* options, uint32_t nnapi_flags);
|
| 59 |
+
|
| 60 |
+
#ifdef __cplusplus
|
| 61 |
+
}
|
| 62 |
+
#endif
|
v1.23.1/headers/onnxruntime_c_api.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
v1.23.1/headers/onnxruntime_cxx_api.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
v1.23.1/headers/onnxruntime_cxx_inline.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
v1.23.1/headers/onnxruntime_ep_c_api.h
ADDED
|
@@ -0,0 +1,988 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
// Do not include this file directly. Please include "onnxruntime_c_api.h" instead.
|
| 5 |
+
|
| 6 |
+
#ifdef __cplusplus
|
| 7 |
+
extern "C" {
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
ORT_RUNTIME_CLASS(Ep);
|
| 11 |
+
ORT_RUNTIME_CLASS(EpFactory);
|
| 12 |
+
ORT_RUNTIME_CLASS(EpGraphSupportInfo);
|
| 13 |
+
ORT_RUNTIME_CLASS(MemoryDevice); // opaque class to wrap onnxruntime::OrtDevice
|
| 14 |
+
ORT_RUNTIME_CLASS(NodeComputeContext);
|
| 15 |
+
|
| 16 |
+
ORT_RUNTIME_CLASS(DataTransferImpl);
|
| 17 |
+
ORT_RUNTIME_CLASS(SyncNotificationImpl);
|
| 18 |
+
ORT_RUNTIME_CLASS(SyncStreamImpl);
|
| 19 |
+
|
| 20 |
+
// struct that an EP implements for IDataTransfer to copy between devices it uses and CPU
|
| 21 |
+
struct OrtDataTransferImpl {
|
| 22 |
+
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION
|
| 23 |
+
|
| 24 |
+
/** \brief Release the OrtDataTransferImpl instance.
|
| 25 |
+
*
|
| 26 |
+
* This is called by ORT when the OrtDataTransferImpl instance is no longer needed.
|
| 27 |
+
* The implementation should release any resources held by the instance.
|
| 28 |
+
*
|
| 29 |
+
* \param[in] this_ptr Pointer to the OrtDataTransferImpl instance.
|
| 30 |
+
*
|
| 31 |
+
* \since Version 1.23.
|
| 32 |
+
*/
|
| 33 |
+
ORT_API_T(void, Release, _In_ OrtDataTransferImpl* this_ptr);
|
| 34 |
+
|
| 35 |
+
/** \brief Check if the implementation can copy between the source and destination memory devices.
|
| 36 |
+
*
|
| 37 |
+
* \param[in] this_ptr Pointer to the OrtDataTransferImpl instance.
|
| 38 |
+
* \param[in] src_memory_device Source OrtMemoryDevice to copy from.
|
| 39 |
+
* \param[in] dst_memory_device Destination OrtMemoryDevice to copy to.
|
| 40 |
+
* \return True if the implementation can copy between the devices.
|
| 41 |
+
*
|
| 42 |
+
* \since Version 1.23.
|
| 43 |
+
*/
|
| 44 |
+
ORT_API_T(bool, CanCopy, _In_ const OrtDataTransferImpl* this_ptr,
|
| 45 |
+
_In_ const OrtMemoryDevice* src_memory_device, _In_ const OrtMemoryDevice* dst_memory_device);
|
| 46 |
+
|
| 47 |
+
/** \brief Copy tensors from src_tensors to dst_tensors using the provided streams.
|
| 48 |
+
*
|
| 49 |
+
* The implementation can use the provided streams to perform asynchronous copies if supported.
|
| 50 |
+
* If a stream is not available, the copy is performed synchronously.
|
| 51 |
+
*
|
| 52 |
+
* \param[in] this_ptr Pointer to the OrtDataTransferImpl instance.
|
| 53 |
+
* \param[in] src_tensors Array of source OrtValue pointers to copy from.
|
| 54 |
+
* \param[in] dst_tensors Array of destination OrtValue pointers to copy to.
|
| 55 |
+
* \param[in] streams Array of OrtSyncStream pointers for the copy operations, if the execution provider is stream
|
| 56 |
+
* aware. nullptr if it is not.
|
| 57 |
+
* \param[in] num_tensors Number of tensors to copy.
|
| 58 |
+
*
|
| 59 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 60 |
+
*
|
| 61 |
+
* \since Version 1.23.
|
| 62 |
+
*/
|
| 63 |
+
ORT_API2_STATUS(CopyTensors, _In_ OrtDataTransferImpl* this_ptr,
|
| 64 |
+
_In_reads_(num_tensors) const OrtValue** src_tensors,
|
| 65 |
+
_In_reads_(num_tensors) OrtValue** dst_tensors,
|
| 66 |
+
_In_reads_(num_tensors) OrtSyncStream** streams,
|
| 67 |
+
_In_ size_t num_tensors);
|
| 68 |
+
};
|
| 69 |
+
|
| 70 |
+
/** \brief Struct that an EP implements for Stream Notifications.
|
| 71 |
+
*
|
| 72 |
+
* \since Version 1.23.
|
| 73 |
+
*/
|
| 74 |
+
struct OrtSyncNotificationImpl {
|
| 75 |
+
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION
|
| 76 |
+
|
| 77 |
+
/** \brief Release the OrtSyncNotificationImpl instance.
|
| 78 |
+
*
|
| 79 |
+
* This is called by ORT when the OrtSyncNotificationImpl instance is no longer needed.
|
| 80 |
+
* The implementation should release any resources held by the instance.
|
| 81 |
+
*
|
| 82 |
+
* \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
|
| 83 |
+
*
|
| 84 |
+
* \since Version 1.23.
|
| 85 |
+
*/
|
| 86 |
+
ORT_API_T(void, Release, _In_ OrtSyncNotificationImpl* this_ptr);
|
| 87 |
+
|
| 88 |
+
/** \brief Called by ORT to activate the notification.
|
| 89 |
+
*
|
| 90 |
+
* \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
|
| 91 |
+
*
|
| 92 |
+
* \since Version 1.23.
|
| 93 |
+
*/
|
| 94 |
+
ORT_API2_STATUS(Activate, _In_ OrtSyncNotificationImpl* this_ptr);
|
| 95 |
+
|
| 96 |
+
/** \brief Wait for a device to device operation to complete.
|
| 97 |
+
*
|
| 98 |
+
* \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
|
| 99 |
+
* \param[in] stream The OrtSyncStream instance that will wait on this notification to be activated.
|
| 100 |
+
*
|
| 101 |
+
* \since Version 1.23.
|
| 102 |
+
*/
|
| 103 |
+
ORT_API2_STATUS(WaitOnDevice, _In_ OrtSyncNotificationImpl* this_ptr, _In_ OrtSyncStream* consumer_stream);
|
| 104 |
+
|
| 105 |
+
/** \brief Wait for a device to host operation to complete.
|
| 106 |
+
*
|
| 107 |
+
* \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance.
|
| 108 |
+
*
|
| 109 |
+
* \since Version 1.23.
|
| 110 |
+
*/
|
| 111 |
+
ORT_API2_STATUS(WaitOnHost, _In_ OrtSyncNotificationImpl* this_ptr);
|
| 112 |
+
};
|
| 113 |
+
|
| 114 |
+
/** \brief Struct that an EP implements if it wishes to implement Stream support.
|
| 115 |
+
*
|
| 116 |
+
* This struct provides the overrides for onnxruntime::Stream's virtual methods.
|
| 117 |
+
*
|
| 118 |
+
* \since Version 1.23.
|
| 119 |
+
*/
|
| 120 |
+
struct OrtSyncStreamImpl {
|
| 121 |
+
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION
|
| 122 |
+
|
| 123 |
+
/** \brief Release the OrtSyncStreamImpl instance.
|
| 124 |
+
*
|
| 125 |
+
* This is called by ORT when the OrtSyncStreamImpl instance is no longer needed.
|
| 126 |
+
* The implementation should release any resources held by the instance.
|
| 127 |
+
*
|
| 128 |
+
* \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
|
| 129 |
+
*
|
| 130 |
+
* \since Version 1.23.
|
| 131 |
+
*/
|
| 132 |
+
ORT_API_T(void, Release, _In_ OrtSyncStreamImpl* this_ptr);
|
| 133 |
+
|
| 134 |
+
/** \brief Get the handle of the stream.
|
| 135 |
+
*
|
| 136 |
+
* This returns the native handle for the stream. e.g. cudaStream_t for CUDA streams.
|
| 137 |
+
*
|
| 138 |
+
* \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
|
| 139 |
+
* \return The handle of the stream.
|
| 140 |
+
*
|
| 141 |
+
* \since Version 1.23.
|
| 142 |
+
*/
|
| 143 |
+
ORT_API_T(void*, GetHandle, _In_ OrtSyncStreamImpl* this_ptr);
|
| 144 |
+
|
| 145 |
+
/** \brief Create an OrtSyncNotificationImpl for the OrtSyncStreamImpl instance.
|
| 146 |
+
*
|
| 147 |
+
* \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance
|
| 148 |
+
* \param[out] notification The new OrtSyncNotificationImpl instance.
|
| 149 |
+
*
|
| 150 |
+
* \since Version 1.23.
|
| 151 |
+
*/
|
| 152 |
+
ORT_API2_STATUS(CreateNotification, _In_ OrtSyncStreamImpl* this_ptr,
|
| 153 |
+
_Outptr_ OrtSyncNotificationImpl** notification);
|
| 154 |
+
|
| 155 |
+
/** \brief Flush the stream.
|
| 156 |
+
*
|
| 157 |
+
* This is called by ORT to flush the stream, ensuring that all operations submitted to the stream are completed.
|
| 158 |
+
*
|
| 159 |
+
* \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
|
| 160 |
+
*
|
| 161 |
+
* \since Version 1.23.
|
| 162 |
+
*/
|
| 163 |
+
ORT_API2_STATUS(Flush, _In_ OrtSyncStreamImpl* this_ptr);
|
| 164 |
+
|
| 165 |
+
/** \brief Notify the stream that a session run has ended.
|
| 166 |
+
*
|
| 167 |
+
* This is called by ORT to notify the stream that a session run has ended, allowing the stream to perform any
|
| 168 |
+
* necessary cleanup or finalization.
|
| 169 |
+
*
|
| 170 |
+
* \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance.
|
| 171 |
+
*
|
| 172 |
+
* \since Version 1.23.
|
| 173 |
+
*/
|
| 174 |
+
ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr);
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
struct OrtNodeFusionOptions;
|
| 178 |
+
typedef struct OrtNodeFusionOptions OrtNodeFusionOptions;
|
| 179 |
+
|
| 180 |
+
struct OrtNodeComputeInfo;
|
| 181 |
+
typedef struct OrtNodeComputeInfo OrtNodeComputeInfo;
|
| 182 |
+
|
| 183 |
+
/**
|
| 184 |
+
* \brief The OrtNodeFusionOptions struct specifies options for fusing nodes supported by an execution provider.
|
| 185 |
+
*
|
| 186 |
+
* Refer to OrtEpApi::EpGraphSupportInfo_AddNodesToFuse.
|
| 187 |
+
*
|
| 188 |
+
* \since Version 1.23.
|
| 189 |
+
*/
|
| 190 |
+
struct OrtNodeFusionOptions {
|
| 191 |
+
/** \brief The ONNX Runtime version the OrtNodeFusionOptions was compiled with.
|
| 192 |
+
*
|
| 193 |
+
* Implementation should set to ORT_API_VERSION.
|
| 194 |
+
* ORT will use this to ensure it does not use members that were not available when the EP library was compiled.
|
| 195 |
+
*
|
| 196 |
+
* \since Version 1.23.
|
| 197 |
+
*/
|
| 198 |
+
uint32_t ort_version_supported;
|
| 199 |
+
|
| 200 |
+
/** \brief If set to true, specify that the execution provider does not require ONNX Runtime to provide constant
|
| 201 |
+
* initializers as inputs to the fused node during model inference. This is used when the execution
|
| 202 |
+
* provider saves a copy of constant initializers, and allows ONNX Runtime to release constant initializers that
|
| 203 |
+
* are not used by any execution provider.
|
| 204 |
+
*
|
| 205 |
+
* If not specified, defaults to false. That is, ONNX Runtime provides constant initializers as inputs to
|
| 206 |
+
* the fused node by default.
|
| 207 |
+
*
|
| 208 |
+
* \since Version 1.23.
|
| 209 |
+
*/
|
| 210 |
+
bool drop_constant_initializers;
|
| 211 |
+
|
| 212 |
+
// const OrtNode* fused_node_schema;
|
| 213 |
+
};
|
| 214 |
+
|
| 215 |
+
/**
|
| 216 |
+
* \brief The OrtNodeComputeInfo struct provides functions that an OrtEp implements to specify the compute
|
| 217 |
+
* function for a compiled OrtGraph instance.
|
| 218 |
+
* \since Version 1.23.
|
| 219 |
+
*/
|
| 220 |
+
struct OrtNodeComputeInfo {
|
| 221 |
+
/** \brief The ONNX Runtime version the OrtNodeComputeInfo was compiled with.
|
| 222 |
+
*
|
| 223 |
+
* Implementation should set to ORT_API_VERSION.
|
| 224 |
+
* ORT will use this to ensure it does not call functions that were not available when the EP library was compiled.
|
| 225 |
+
*
|
| 226 |
+
* \since Version 1.23.
|
| 227 |
+
*/
|
| 228 |
+
uint32_t ort_version_supported;
|
| 229 |
+
|
| 230 |
+
/** \brief Creates an opaque compute state object that is then passed to the Compute() function during inference.
|
| 231 |
+
* \param[in] this_ptr The OrtNodeComputeInfo instance.
|
| 232 |
+
* \param[in] compute_context OrtNodeComputeContext instance that contains compiled/fused node's name and host
|
| 233 |
+
* memory allocation functions. Can optionally be used to build the compute state.
|
| 234 |
+
* \param[out] compute_state Output parameter that is assigned the opaque computation state. ONNX Runtime calls
|
| 235 |
+
* ReleaseState() (after calling Compute()) to allow the implementer to release the
|
| 236 |
+
* compute state.
|
| 237 |
+
*
|
| 238 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 239 |
+
*
|
| 240 |
+
* \since Version 1.23.
|
| 241 |
+
*/
|
| 242 |
+
OrtStatus*(ORT_API_CALL* CreateState)(_In_ OrtNodeComputeInfo* this_ptr,
|
| 243 |
+
_In_ OrtNodeComputeContext* compute_context,
|
| 244 |
+
_Outptr_ void** compute_state);
|
| 245 |
+
|
| 246 |
+
/** \brief Computation function called to execute the fused node compiled by an OrtEp instance.
|
| 247 |
+
* \param[in] this_ptr The OrtNodeComputeInfo instance.
|
| 248 |
+
* \param[in] compute_state The opaque computation state returned by CreateState().
|
| 249 |
+
* \param[in] kernel_context The OrtKernelContext instance used to access inputs/outputs.
|
| 250 |
+
*
|
| 251 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 252 |
+
*
|
| 253 |
+
* \since Version 1.23.
|
| 254 |
+
*/
|
| 255 |
+
OrtStatus*(ORT_API_CALL* Compute)(_In_ OrtNodeComputeInfo* this_ptr, _In_ void* compute_state,
|
| 256 |
+
_In_ OrtKernelContext* kernel_context);
|
| 257 |
+
|
| 258 |
+
/** \brief Releases the compute state returned by CreateState().
|
| 259 |
+
* \param[in] this_ptr The OrtNodeComputeInfo instance.
|
| 260 |
+
* \param[inout] compute_state The opaque compute state returned by CreateState().
|
| 261 |
+
*
|
| 262 |
+
* \since Version 1.23.
|
| 263 |
+
*/
|
| 264 |
+
void(ORT_API_CALL* ReleaseState)(_In_ OrtNodeComputeInfo* this_ptr, _Frees_ptr_opt_ void* compute_state);
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
struct OrtEpApi {
|
| 268 |
+
/** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice.
|
| 269 |
+
* \param[in] ep_factory Execution provider factory that is creating the instance.
|
| 270 |
+
* \param[in] hardware_device Hardware device that the EP can utilize.
|
| 271 |
+
* \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used
|
| 272 |
+
* during execution provider selection and passed to CreateEp.
|
| 273 |
+
* ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
|
| 274 |
+
* \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added
|
| 275 |
+
* to the Session configuration options if the execution provider is selected.
|
| 276 |
+
* ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
|
| 277 |
+
* \param ep_device OrtExecutionDevice that is created.
|
| 278 |
+
*
|
| 279 |
+
* \since Version 1.22.
|
| 280 |
+
*/
|
| 281 |
+
ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory,
|
| 282 |
+
_In_ const OrtHardwareDevice* hardware_device,
|
| 283 |
+
_In_opt_ const OrtKeyValuePairs* ep_metadata,
|
| 284 |
+
_In_opt_ const OrtKeyValuePairs* ep_options,
|
| 285 |
+
_Out_ OrtEpDevice** ep_device);
|
| 286 |
+
|
| 287 |
+
ORT_CLASS_RELEASE(EpDevice);
|
| 288 |
+
|
| 289 |
+
/** \brief Specify nodes that are supported by an OrtEp and should be fused into one node.
|
| 290 |
+
*
|
| 291 |
+
* Because the nodes will be fused into one "fused node", there must not exist an unsupported node in
|
| 292 |
+
* a path between two of the provided nodes. Otherwise, the graph will become invalid.
|
| 293 |
+
*
|
| 294 |
+
* This function can be called multiple times. A subsequent call to this function will force the next set of
|
| 295 |
+
* nodes to be fused into a different node.
|
| 296 |
+
*
|
| 297 |
+
* \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes.
|
| 298 |
+
* \param[in] nodes Array of nodes supported by the EP that should be fused/compiled.
|
| 299 |
+
* \param[in] num_nodes The number of supported nodes.
|
| 300 |
+
* \param[in] node_fusion_options Optional node fusion options. Ignored if set to NULL.
|
| 301 |
+
*
|
| 302 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 303 |
+
*
|
| 304 |
+
* \since Version 1.23.
|
| 305 |
+
*/
|
| 306 |
+
ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info,
|
| 307 |
+
_In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes,
|
| 308 |
+
_In_opt_ const OrtNodeFusionOptions* node_fusion_options);
|
| 309 |
+
|
| 310 |
+
/** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel.
|
| 311 |
+
*
|
| 312 |
+
* \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported node.
|
| 313 |
+
* \param[in] node The supported OrtNode instance.
|
| 314 |
+
*
|
| 315 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 316 |
+
*
|
| 317 |
+
* \since Version 1.23.
|
| 318 |
+
*/
|
| 319 |
+
ORT_API2_STATUS(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info,
|
| 320 |
+
_In_ const OrtNode* node);
|
| 321 |
+
|
| 322 |
+
/** \brief Query a OrtNodeComputeContext for the name of the node that encapsulates the compiled/fused node.
|
| 323 |
+
*
|
| 324 |
+
* Used in OrtNodeComputeInfo::CreateComputeState().
|
| 325 |
+
*
|
| 326 |
+
* \param[in] context The OrtNodeComputeContext instance to query.
|
| 327 |
+
* \return The node's name.
|
| 328 |
+
*
|
| 329 |
+
* \note Returned string is owned by ORT and valid only while OrtNodeComputeInfo::CreateComputeState() is called.
|
| 330 |
+
*
|
| 331 |
+
* \since Version 1.23.
|
| 332 |
+
*/
|
| 333 |
+
ORT_API_T(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context);
|
| 334 |
+
|
| 335 |
+
/** \brief Register an allocator with the OrtEpDevice.
|
| 336 |
+
*
|
| 337 |
+
* This allows an EP to provide OrtMemoryInfo for DEFAULT and HOST_ACCESSIBLE memory type as needed.
|
| 338 |
+
* The registered values will be used in calls to OrtEpFactory::CreateAllocator to ensure the required allocator/s
|
| 339 |
+
* are available for EP usage.
|
| 340 |
+
*
|
| 341 |
+
* Multiple calls for the same entry type will replace a previous entry.
|
| 342 |
+
*
|
| 343 |
+
* Available entries:
|
| 344 |
+
* - OrtDeviceAllocator with type of OrtDeviceMemoryType_DEFAULT
|
| 345 |
+
* - OrtDeviceAllocator with type of OrtDeviceMemoryType_HOST_ACCESSIBLE
|
| 346 |
+
* - OrtReadOnlyAllocator with type of OrtDeviceMemoryType_DEFAULT
|
| 347 |
+
* - if provided this allocator will only be used to copy initializers to the device the EP uses.
|
| 348 |
+
* ORT will use the OrtDeviceAllocator if not provided.
|
| 349 |
+
*
|
| 350 |
+
* \param[in] ep_device The OrtEpDevice instance to register the OrtMemoryInfo with.
|
| 351 |
+
* \param[in] allocator_memory_info The OrtMemoryInfo information for the allocator.
|
| 352 |
+
*
|
| 353 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 354 |
+
*
|
| 355 |
+
* \since Version 1.23.
|
| 356 |
+
*/
|
| 357 |
+
ORT_API2_STATUS(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device,
|
| 358 |
+
_In_ const OrtMemoryInfo* allocator_memory_info);
|
| 359 |
+
|
| 360 |
+
/** \brief Get the OrtMemoryDevice from an OrtMemoryInfo instance.
|
| 361 |
+
*
|
| 362 |
+
* This is required for OrtDataTransferImpl (which implements onnxruntime::IDataTransfer) where the OrtMemoryDevice
|
| 363 |
+
* is used in the CanCopy and CopyTensors functions.
|
| 364 |
+
*
|
| 365 |
+
* \param[in] memory_info The OrtMemoryInfo instance to get the memory device from.
|
| 366 |
+
* \return The OrtMemoryDevice associated with the OrtMemoryInfo instance.
|
| 367 |
+
*
|
| 368 |
+
* \since Version 1.23.
|
| 369 |
+
*/
|
| 370 |
+
ORT_API_T(const OrtMemoryDevice*, MemoryInfo_GetMemoryDevice, _In_ const OrtMemoryInfo* memory_info);
|
| 371 |
+
|
| 372 |
+
/** \brief Get the OrtMemoryDevice from an OrtValue instance if it contains a Tensor.
|
| 373 |
+
*
|
| 374 |
+
* \param[in] value The OrtValue instance to get the memory device from.
|
| 375 |
+
* \return Memory device if OrtValue contains a Tensor, nullptr otherwise.
|
| 376 |
+
*
|
| 377 |
+
* \since Version 1.23.
|
| 378 |
+
*/
|
| 379 |
+
ORT_API_T(const OrtMemoryDevice*, Value_GetMemoryDevice, _In_ const OrtValue* value);
|
| 380 |
+
|
| 381 |
+
/** \brief Compare two OrtMemoryDevice instances for equality.
|
| 382 |
+
*
|
| 383 |
+
* This is used to check if two memory devices are the same.
|
| 384 |
+
* Used to implement DataTransferImpl::CanCopy.
|
| 385 |
+
*
|
| 386 |
+
* \param[in] a The first OrtMemoryDevice instance to compare.
|
| 387 |
+
* \param[in] b The second OrtMemoryDevice instance to compare.
|
| 388 |
+
* \return True if the two OrtMemoryDevice instances are equal, false otherwise.
|
| 389 |
+
*
|
| 390 |
+
* \since Version 1.23.
|
| 391 |
+
*/
|
| 392 |
+
ORT_API_T(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b);
|
| 393 |
+
|
| 394 |
+
/** \brief Get the OrtMemoryInfoDeviceType value from an OrtMemoryDevice instance.
|
| 395 |
+
*
|
| 396 |
+
* \param[in] memory_device OrtMemoryDevice instance.
|
| 397 |
+
* \return The OrtMemoryInfoDeviceType value.
|
| 398 |
+
*
|
| 399 |
+
* \since Version 1.23.
|
| 400 |
+
*/
|
| 401 |
+
ORT_API_T(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemoryDevice* memory_device);
|
| 402 |
+
|
| 403 |
+
/** \brief Get the OrtDeviceMemoryType value from an OrtMemoryDevice instance.
|
| 404 |
+
*
|
| 405 |
+
* \param[in] memory_device OrtMemoryDevice instance.
|
| 406 |
+
* \return The OrtDeviceMemoryType value.
|
| 407 |
+
*
|
| 408 |
+
* \since Version 1.23.
|
| 409 |
+
*/
|
| 410 |
+
ORT_API_T(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device);
|
| 411 |
+
|
| 412 |
+
/** \brief Get the vendor ID from an OrtMemoryDevice instance.
|
| 413 |
+
*
|
| 414 |
+
* The vendor ID is used to identify the vendor of the device, and is typically set to the PCI vendor ID.
|
| 415 |
+
*
|
| 416 |
+
* If the device is not vendor specific (e.g. CPU memory) the vendor ID is set to 0.
|
| 417 |
+
*
|
| 418 |
+
* \param[in] memory_device OrtMemoryDevice instance.
|
| 419 |
+
* \return The vendor ID value.
|
| 420 |
+
*
|
| 421 |
+
* \since Version 1.23.
|
| 422 |
+
*/
|
| 423 |
+
ORT_API_T(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device);
|
| 424 |
+
|
| 425 |
+
/** \brief Get the device ID from an OrtMemoryDevice instance.
|
| 426 |
+
*
|
| 427 |
+
* \param[in] memory_device OrtMemoryDevice instance.
|
| 428 |
+
* \return The device ID.
|
| 429 |
+
*
|
| 430 |
+
* \since Version 1.23.
|
| 431 |
+
*/
|
| 432 |
+
ORT_API_T(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device);
|
| 433 |
+
|
| 434 |
+
/** \brief Get the OrtSyncStreamImpl associated with an OrtSyncStream instance.
|
| 435 |
+
*
|
| 436 |
+
* This allows an the plugin library to connect its OrtSyncStreamImpl instance with an OrtSyncStream if needed.
|
| 437 |
+
*
|
| 438 |
+
* \param[in] stream The OrtSyncStream instance to find an OrtSyncStreamImpl for.
|
| 439 |
+
* \return The associated OrtSyncStreamImpl if found. nullptr otherwise.
|
| 440 |
+
*
|
| 441 |
+
* \since Version 1.23.
|
| 442 |
+
*
|
| 443 |
+
* \remarks There should always be an OrtSyncStreamImpl associated with an OrtSyncStream instance that the EP gets.
|
| 444 |
+
*/
|
| 445 |
+
ORT_API_T(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* stream);
|
| 446 |
+
|
| 447 |
+
/** \brief Get the current sync ID for a stream.
|
| 448 |
+
*
|
| 449 |
+
* \param[in] stream The OrtSyncStream to get the sync ID for.
|
| 450 |
+
* \return Current sync ID.
|
| 451 |
+
*
|
| 452 |
+
* \since Version 1.23.
|
| 453 |
+
*/
|
| 454 |
+
ORT_API_T(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream);
|
| 455 |
+
|
| 456 |
+
/** \brief Get the sync ID for the last time the consumer_stream waited on the producer_stream.
|
| 457 |
+
*
|
| 458 |
+
* When two streams are synchronized, the sync id represents the event used in that synchronization.
|
| 459 |
+
*
|
| 460 |
+
* \param[in] producer_stream The OrtSyncStream that produced the data.
|
| 461 |
+
* \param[in] consumer_stream The OrtSyncStream that waited on the producer_stream.
|
| 462 |
+
* \return ID for last sync. 0 if no sync has occurred between the two streams.
|
| 463 |
+
*
|
| 464 |
+
* \since Version 1.23.
|
| 465 |
+
*/
|
| 466 |
+
ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream,
|
| 467 |
+
_In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream);
|
| 468 |
+
};
|
| 469 |
+
|
| 470 |
+
/**
|
| 471 |
+
* \brief The data layout type.
|
| 472 |
+
*
|
| 473 |
+
* EPs may specify a preferred data layout type. ORT's default layout type is OrtEpDataLayout_NCHW, or
|
| 474 |
+
* OrtEpDataLayout_Default.
|
| 475 |
+
*
|
| 476 |
+
* \since Version 1.23.
|
| 477 |
+
*/
|
| 478 |
+
typedef enum OrtEpDataLayout {
|
| 479 |
+
OrtEpDataLayout_NCHW = 0,
|
| 480 |
+
OrtEpDataLayout_NHWC,
|
| 481 |
+
|
| 482 |
+
OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
|
| 483 |
+
} OrtEpDataLayout;
|
| 484 |
+
|
| 485 |
+
/**
|
| 486 |
+
* \brief The OrtEp struct provides functions to implement for an execution provider.
|
| 487 |
+
* \since Version 1.22.
|
| 488 |
+
*/
|
| 489 |
+
struct OrtEp {
|
| 490 |
+
/** \brief The ONNX Runtime version the execution provider was compiled with.
|
| 491 |
+
*
|
| 492 |
+
* Implementation should set to ORT_API_VERSION.
|
| 493 |
+
* ORT will use this to ensure it does not call functions that were not available when the library was compiled.
|
| 494 |
+
*
|
| 495 |
+
* \since Version 1.22.
|
| 496 |
+
*/
|
| 497 |
+
uint32_t ort_version_supported;
|
| 498 |
+
|
| 499 |
+
/** \brief Get the execution provider name.
|
| 500 |
+
*
|
| 501 |
+
* The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
|
| 502 |
+
*
|
| 503 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 504 |
+
* \return The execution provider name.
|
| 505 |
+
*
|
| 506 |
+
* \since Version 1.22.
|
| 507 |
+
*/
|
| 508 |
+
ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr);
|
| 509 |
+
|
| 510 |
+
/** \brief Get information about the nodes supported by the OrtEp instance.
|
| 511 |
+
*
|
| 512 |
+
* IMPORTANT: This is not the final version of this API function. This is currently experimental but will
|
| 513 |
+
* be stabilized by the ONNX Runtime 1.23 release.
|
| 514 |
+
*
|
| 515 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 516 |
+
* \param[in] graph The OrtGraph instance for which to populate node support. The OrtGraph could be a nested subgraph
|
| 517 |
+
* contained by a node (e.g., an If or Loop node). ONNX Runtime calls this function separately
|
| 518 |
+
* for each nested subgraph.
|
| 519 |
+
* \param[inout] graph_support_info OrtEpGraphSupportInfo instance that the implementer must fill out in order to
|
| 520 |
+
* specify the supported nodes.
|
| 521 |
+
*
|
| 522 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 523 |
+
*
|
| 524 |
+
* \since Version 1.23.
|
| 525 |
+
*/
|
| 526 |
+
ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph,
|
| 527 |
+
_Inout_ OrtEpGraphSupportInfo* graph_support_info);
|
| 528 |
+
|
| 529 |
+
/** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance
|
| 530 |
+
* for each OrtGraph in order to define its computation function.
|
| 531 |
+
*
|
| 532 |
+
* If the session is configured to generate a pre-compiled model, the execution provider must return EPContext nodes,
|
| 533 |
+
* as OrtNode instances, that ONNX Runtime uses to create a pre-compiled model, known as an "EPContext model".
|
| 534 |
+
* An EPContext model contains EPContext nodes. Each EPContext node encapsulates the pre-compiled binary data for a
|
| 535 |
+
* OrtGraph compiled for a specific execution provider. For more details about the EPContext design, refer to:
|
| 536 |
+
* \htmlonly
|
| 537 |
+
* <a href="https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html">EPContext design document.</a>
|
| 538 |
+
* \endhtmlonly
|
| 539 |
+
*
|
| 540 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 541 |
+
* \param[in] graphs Array of `count` OrtGraph instances to compile. Each graph contains only the nodes for
|
| 542 |
+
* which the execution provider indicated support. Nested subgraphs contained by a
|
| 543 |
+
* node, such as an If or Loop, have separate OrtGraph instances.
|
| 544 |
+
* \param[in] fused_nodes Array of `count` fused nodes that will replace the compiled graphs.
|
| 545 |
+
* Each fused node is an OrtNode initialized with the intended fused node name and
|
| 546 |
+
* input/output information.
|
| 547 |
+
* \param[in] count The number of OrtGraph instances to compile.
|
| 548 |
+
* \param[out] node_compute_infos Array of `count` OrtNodeComputeInfo instances that define each OrtGraph instance's
|
| 549 |
+
* computation function. The implementer allocates the OrtNodeComputeInfo instances.
|
| 550 |
+
* ORT calls ReleaseNodeComputeInfos() to release multiple instances in a batch.
|
| 551 |
+
* \param[out] ep_context_nodes Output array of `count` OrtNode instances, each representing an EPContext
|
| 552 |
+
* node for a compiled OrtGraph. The execution provider must use
|
| 553 |
+
* OrtModelEditorApi::CreateNode to create the OrtNode instances. ONNX Runtime takes
|
| 554 |
+
* ownership of the OrtNode instances, so the execution provider must NOT call
|
| 555 |
+
* OrtApi::ReleaseNode. Should be ignored if the session is not configured to generate an
|
| 556 |
+
* EPContext model.
|
| 557 |
+
*
|
| 558 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 559 |
+
*
|
| 560 |
+
* \note Do NOT cache the provided OrtGraph instances in any of the OrtNodeComputeInfo functions because the
|
| 561 |
+
* graphs are only valid for the duration of the call to Compile. Any graph/node/input/output
|
| 562 |
+
* names that are needed by the OrtNodeComputeInfo functions must be copied and stored by the OrtEp.
|
| 563 |
+
*
|
| 564 |
+
* \since Version 1.23.
|
| 565 |
+
*/
|
| 566 |
+
ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
|
| 567 |
+
_In_ const OrtNode** fused_nodes, _In_ size_t count,
|
| 568 |
+
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
|
| 569 |
+
_Out_writes_(count) OrtNode** ep_context_nodes);
|
| 570 |
+
|
| 571 |
+
/** \brief Release OrtNodeComputeInfo instances.
|
| 572 |
+
*
|
| 573 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 574 |
+
* \param[inout] node_compute_infos The OrtNodeComputeInfo instances to release.
|
| 575 |
+
* \param[in] num_node_compute_infos The number of OrtNodeComputeInfo instances.
|
| 576 |
+
*
|
| 577 |
+
* \since Version 1.23.
|
| 578 |
+
*/
|
| 579 |
+
ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr,
|
| 580 |
+
OrtNodeComputeInfo** node_compute_infos,
|
| 581 |
+
_In_ size_t num_node_compute_infos);
|
| 582 |
+
|
| 583 |
+
/** \brief Get the EP's preferred data layout.
|
| 584 |
+
*
|
| 585 |
+
* \note Implementation of this function is optional.
|
| 586 |
+
* If not implemented, ORT will assume that this EP prefers the data layout `OrtEpDataLayout::NCHW`.
|
| 587 |
+
*
|
| 588 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 589 |
+
* \param[out] preferred_data_layout The EP's preferred data layout.
|
| 590 |
+
*
|
| 591 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 592 |
+
*
|
| 593 |
+
* \since Version 1.23.
|
| 594 |
+
*/
|
| 595 |
+
ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout);
|
| 596 |
+
|
| 597 |
+
/** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout
|
| 598 |
+
* should be converted to `target_data_layout`.
|
| 599 |
+
* If the EP prefers a non-default data layout (see `GetPreferredDataLayout()`), this function will be called
|
| 600 |
+
* during layout transformation with `target_data_layout` set to the EP's preferred data layout.
|
| 601 |
+
*
|
| 602 |
+
* \note Implementation of this function is optional.
|
| 603 |
+
* If an EP prefers a non-default data layout, it may implement this to customize the specific op data layout
|
| 604 |
+
* preferences at a finer granularity.
|
| 605 |
+
*
|
| 606 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 607 |
+
* \param[in] domain The op domain. An empty string means the ONNX domain.
|
| 608 |
+
* \param[in] op_type The op type.
|
| 609 |
+
* \param[in] target_data_layout The target data layout.
|
| 610 |
+
* \param[out] should_convert Whether the associated node's data layout should be converted to `target_data_layout`.
|
| 611 |
+
* If greater than 0, convert.
|
| 612 |
+
* If 0, don't convert.
|
| 613 |
+
* Otherwise, if less than 0, leave the decision to ORT.
|
| 614 |
+
*
|
| 615 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 616 |
+
*
|
| 617 |
+
* \since Version 1.23.
|
| 618 |
+
*/
|
| 619 |
+
ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr,
|
| 620 |
+
_In_z_ const char* domain, _In_z_ const char* op_type,
|
| 621 |
+
_In_ OrtEpDataLayout target_data_layout,
|
| 622 |
+
_Outptr_ int* should_convert);
|
| 623 |
+
|
| 624 |
+
/** \brief Set dynamic options on this EP.
|
| 625 |
+
*
|
| 626 |
+
* Dynamic options can be set by the user at any time after session creation with `OrtApi::SetEpDynamicOptions()`.
|
| 627 |
+
*
|
| 628 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 629 |
+
* \param[in] option_keys The dynamic option keys.
|
| 630 |
+
* \param[in] option_values The dynamic option values.
|
| 631 |
+
* \param[in] num_options The number of dynamic options.
|
| 632 |
+
*
|
| 633 |
+
* \note Implementation of this function is optional.
|
| 634 |
+
* An EP should only implement this if it needs to handle any dynamic options.
|
| 635 |
+
*
|
| 636 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 637 |
+
*
|
| 638 |
+
* \since Version 1.23.
|
| 639 |
+
*/
|
| 640 |
+
ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr,
|
| 641 |
+
_In_reads_(num_options) const char* const* option_keys,
|
| 642 |
+
_In_reads_(num_options) const char* const* option_values,
|
| 643 |
+
_In_ size_t num_options);
|
| 644 |
+
|
| 645 |
+
/** \brief Called by ORT to notify the EP of the start of a run.
|
| 646 |
+
*
|
| 647 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 648 |
+
* \param[in] run_options The run options for this run.
|
| 649 |
+
*
|
| 650 |
+
* \note Implementation of this function is optional.
|
| 651 |
+
*
|
| 652 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 653 |
+
*
|
| 654 |
+
* \since Version 1.23.
|
| 655 |
+
*/
|
| 656 |
+
ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options);
|
| 657 |
+
|
| 658 |
+
/** \brief Called by ORT to notify the EP of the end of a run.
|
| 659 |
+
*
|
| 660 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 661 |
+
* \param[in] run_options The run options for this run.
|
| 662 |
+
* \param[in] sync_stream Whether any associated stream should be synchronized during this call.
|
| 663 |
+
* Only applicable if there is such a stream.
|
| 664 |
+
*
|
| 665 |
+
* \note Implementation of this function is optional.
|
| 666 |
+
*
|
| 667 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 668 |
+
*
|
| 669 |
+
* \since Version 1.23.
|
| 670 |
+
*/
|
| 671 |
+
ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream);
|
| 672 |
+
|
| 673 |
+
/** \brief Create an OrtAllocator for the given OrtMemoryInfo for an OrtSession.
|
| 674 |
+
*
|
| 675 |
+
* The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using EpDevice_AddAllocatorInfo.
|
| 676 |
+
* Any allocator specific options should be read from the session options.
|
| 677 |
+
*
|
| 678 |
+
* If nullptr OrtEpFactory::CreateAllocator will be used.
|
| 679 |
+
*
|
| 680 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 681 |
+
* \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr.
|
| 682 |
+
* \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used.
|
| 683 |
+
*
|
| 684 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 685 |
+
*
|
| 686 |
+
* \since Version 1.23.
|
| 687 |
+
*/
|
| 688 |
+
ORT_API2_STATUS(CreateAllocator, _In_ OrtEp* this_ptr,
|
| 689 |
+
_In_ const OrtMemoryInfo* memory_info,
|
| 690 |
+
_Outptr_result_maybenull_ OrtAllocator** allocator);
|
| 691 |
+
|
| 692 |
+
/** \brief Create a synchronization stream for the given memory device for an OrtSession.
|
| 693 |
+
*
|
| 694 |
+
* This is used to create a synchronization stream for the execution provider and is used to synchronize
|
| 695 |
+
* operations on the device during model execution.
|
| 696 |
+
* Any stream specific options should be read from the session options.
|
| 697 |
+
*
|
| 698 |
+
* If nullptr OrtEpFactory::CreateSyncStreamForDevice will be used.
|
| 699 |
+
*
|
| 700 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 701 |
+
* \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for.
|
| 702 |
+
* \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not stream aware.
|
| 703 |
+
*
|
| 704 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 705 |
+
*
|
| 706 |
+
* \since Version 1.23.
|
| 707 |
+
*/
|
| 708 |
+
ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEp* this_ptr,
|
| 709 |
+
_In_ const OrtMemoryDevice* memory_device,
|
| 710 |
+
_Outptr_ OrtSyncStreamImpl** stream);
|
| 711 |
+
|
| 712 |
+
/** \brief Get a string with details about the EP stack used to produce a compiled model.
|
| 713 |
+
*
|
| 714 |
+
* This function gets a compatibility information string that contains details about the execution provider
|
| 715 |
+
* used to compile a given model. This string can later be used with ValidateCompiledModelCompatibilityInfo
|
| 716 |
+
* to determine if a compiled model is compatible with the EP.
|
| 717 |
+
*
|
| 718 |
+
* The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
|
| 719 |
+
*
|
| 720 |
+
* \param[in] this_ptr The OrtEp instance.
|
| 721 |
+
* \param[in] graph The OrtGraph instance for which to generate compatibility information.
|
| 722 |
+
*
|
| 723 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 724 |
+
*
|
| 725 |
+
* \since Version 1.23.
|
| 726 |
+
*/
|
| 727 |
+
ORT_API_T(const char*, GetCompiledModelCompatibilityInfo, _In_ OrtEp* this_ptr,
|
| 728 |
+
_In_ const OrtGraph* graph);
|
| 729 |
+
};
|
| 730 |
+
|
| 731 |
+
/** \brief The function signature that ORT will call to create OrtEpFactory instances.
|
| 732 |
+
*
|
| 733 |
+
* This must be available in a function called 'CreateEpFactories' in the execution provider library.
|
| 734 |
+
*
|
| 735 |
+
* \param[in] registered_name The name the execution library is registered with by RegisterExecutionProviderLibrary
|
| 736 |
+
* \param[in] ort_api_base The OrtApiBase instance that is used by the factory to get the OrtApi instance for the
|
| 737 |
+
* version of ORT that the library was compiled against.
|
| 738 |
+
* \param[in] default_logger The default ORT logger that can be used for logging outside of an inference session.
|
| 739 |
+
* \param[in,out] factories The implementation should create and add OrtEpFactory instances to this
|
| 740 |
+
* pre-allocated array.
|
| 741 |
+
* i.e. usage is `factories[0] = new MyEpFactory();`
|
| 742 |
+
* \param[in] max_factories The maximum number of OrtEpFactory instances that can be added to `factories`.
|
| 743 |
+
* Current default is to allow 4 factories. This can be increased in the future if needed.
|
| 744 |
+
* \param[out] num_factories The number of OrtEpFactory instances created by the factory and added to `factories`.
|
| 745 |
+
*
|
| 746 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 747 |
+
*
|
| 748 |
+
* \since Version 1.22.
|
| 749 |
+
*/
|
| 750 |
+
typedef OrtStatus* (*CreateEpApiFactoriesFn)(_In_ const char* registered_name, _In_ const OrtApiBase* ort_api_base,
|
| 751 |
+
_In_ const OrtLogger* default_logger,
|
| 752 |
+
_Inout_ OrtEpFactory** factories, _In_ size_t max_factories,
|
| 753 |
+
_Out_ size_t* num_factories);
|
| 754 |
+
|
| 755 |
+
/** \brief The function signature that ORT will call to release an OrtEpFactory instance.
|
| 756 |
+
*
|
| 757 |
+
* This must be available in a function called 'ReleaseEpFactory' in the execution provider library.
|
| 758 |
+
*
|
| 759 |
+
* \param[in] factory The OrtEpFactory instance to release.
|
| 760 |
+
*
|
| 761 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 762 |
+
*
|
| 763 |
+
* \since Version 1.22.
|
| 764 |
+
*/
|
| 765 |
+
typedef OrtStatus* (*ReleaseEpApiFactoryFn)(_In_ OrtEpFactory* factory);
|
| 766 |
+
|
| 767 |
+
/**
|
| 768 |
+
* \brief The OrtEpFactory provides functions to create and manage execution providers.
|
| 769 |
+
* \since Version 1.22.
|
| 770 |
+
*/
|
| 771 |
+
struct OrtEpFactory {
|
| 772 |
+
/** \brief The ONNX Runtime version the execution provider was compiled with.
|
| 773 |
+
*
|
| 774 |
+
* Implementation should set to ORT_API_VERSION.
|
| 775 |
+
* ORT will use this to ensure it does not call functions that were not available when the library was compiled.
|
| 776 |
+
*
|
| 777 |
+
* \since Version 1.22.
|
| 778 |
+
*/
|
| 779 |
+
uint32_t ort_version_supported;
|
| 780 |
+
|
| 781 |
+
/** \brief Get the name of the execution provider that the factory creates.
|
| 782 |
+
*
|
| 783 |
+
* The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
|
| 784 |
+
*
|
| 785 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 786 |
+
* \return The name of the execution provider the factory creates.
|
| 787 |
+
*
|
| 788 |
+
* \since Version 1.22.
|
| 789 |
+
*/
|
| 790 |
+
ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr);
|
| 791 |
+
|
| 792 |
+
/** \brief Get the name of vendor who owns the execution provider that the factory creates.
|
| 793 |
+
*
|
| 794 |
+
* The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
|
| 795 |
+
*
|
| 796 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 797 |
+
* \return vendor The vendor name of the execution provider the factory creates.
|
| 798 |
+
*
|
| 799 |
+
* \since Version 1.22.
|
| 800 |
+
*/
|
| 801 |
+
ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr); // return EP vendor
|
| 802 |
+
|
| 803 |
+
/** \brief Get information from the execution provider about OrtHardwareDevice support.
|
| 804 |
+
*
|
| 805 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 806 |
+
* Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice.
|
| 807 |
+
* \param[in] devices The OrtHardwareDevice instances that are available.
|
| 808 |
+
* \param[in] num_devices The number of OrtHardwareDevice instances.
|
| 809 |
+
* \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use.
|
| 810 |
+
* The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice
|
| 811 |
+
* instances to this pre-allocated array. ORT will take ownership of the values returned.
|
| 812 |
+
* i.e. usage is `ep_devices[0] = <ptr to OrtEpDevice created with OrtEpApi::CreateEpDevice>;`
|
| 813 |
+
* \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices.
|
| 814 |
+
* Current default is 8. This can be increased if needed.
|
| 815 |
+
* \param[out] num_ep_devices The number of EP devices added to ep_devices.
|
| 816 |
+
* \return true if the factory can create an execution provider that uses `device`.
|
| 817 |
+
*
|
| 818 |
+
* \since Version 1.22.
|
| 819 |
+
*/
|
| 820 |
+
ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr,
|
| 821 |
+
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
|
| 822 |
+
_In_ size_t num_devices,
|
| 823 |
+
_Inout_ OrtEpDevice** ep_devices,
|
| 824 |
+
_In_ size_t max_ep_devices,
|
| 825 |
+
_Out_ size_t* num_ep_devices);
|
| 826 |
+
|
| 827 |
+
/** \brief Function to create an OrtEp instance for use in a Session.
|
| 828 |
+
*
|
| 829 |
+
* ORT will call ReleaseEp to release the instance when it is no longer needed.
|
| 830 |
+
*
|
| 831 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 832 |
+
* \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use.
|
| 833 |
+
* May be a subset of the OrtHardwareDevice instances that the execution provider's factory
|
| 834 |
+
* set as supported in the call to OrtEpFactory::GetSupportedDevices.
|
| 835 |
+
* \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each
|
| 836 |
+
* device.
|
| 837 |
+
* \param[in] num_devices The number of devices the execution provider was selected for.
|
| 838 |
+
* \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the
|
| 839 |
+
* session. This will include ep_options from GetSupportedDevices as well as any
|
| 840 |
+
* user provided overrides.
|
| 841 |
+
* Execution provider options will have been added with a prefix of 'ep.[ep name].'.
|
| 842 |
+
* The OrtSessionOptions instance will NOT be valid after this call and should not be
|
| 843 |
+
* stored for later use.
|
| 844 |
+
* \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging.
|
| 845 |
+
* \param[out] ep The OrtEp instance created by the factory.
|
| 846 |
+
*
|
| 847 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 848 |
+
*
|
| 849 |
+
* \since Version 1.22.
|
| 850 |
+
*/
|
| 851 |
+
ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr,
|
| 852 |
+
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
|
| 853 |
+
_In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs,
|
| 854 |
+
_In_ size_t num_devices,
|
| 855 |
+
_In_ const OrtSessionOptions* session_options,
|
| 856 |
+
_In_ const OrtLogger* logger, _Outptr_ OrtEp** ep);
|
| 857 |
+
|
| 858 |
+
/** \brief Release the OrtEp instance.
|
| 859 |
+
*
|
| 860 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 861 |
+
* \param[in] ep The OrtEp instance to release.
|
| 862 |
+
*
|
| 863 |
+
* \since Version 1.22.
|
| 864 |
+
*/
|
| 865 |
+
ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep);
|
| 866 |
+
|
| 867 |
+
/** \brief Get the vendor id who owns the execution provider that the factory creates.
|
| 868 |
+
*
|
| 869 |
+
* This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies
|
| 870 |
+
*
|
| 871 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 872 |
+
* \return vendor_id The vendor ID of the execution provider the factory creates.
|
| 873 |
+
*
|
| 874 |
+
* \since Version 1.23.
|
| 875 |
+
*/
|
| 876 |
+
ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr);
|
| 877 |
+
|
| 878 |
+
/** \brief Get the version of the execution provider that the factory creates.
|
| 879 |
+
*
|
| 880 |
+
* The version string should adhere to the Semantic Versioning 2.0 specification
|
| 881 |
+
* (https://github.com/semver/semver/blob/v2.0.0/semver.md).
|
| 882 |
+
*
|
| 883 |
+
* The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it.
|
| 884 |
+
*
|
| 885 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 886 |
+
* \return The execution provider version string.
|
| 887 |
+
*
|
| 888 |
+
* \since Version 1.23.
|
| 889 |
+
*/
|
| 890 |
+
ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr);
|
| 891 |
+
|
| 892 |
+
/** \brief Validate the compatibility of a compiled model with the execution provider factory for one or more devices.
|
| 893 |
+
*
|
| 894 |
+
* Given a compatibility info string produced during model compilation, the EP factory should determine whether the
|
| 895 |
+
* compiled model is compatible with the EP factory when targeting the provided hardware devices. All devices provided
|
| 896 |
+
* must belong to the same execution provider instance that this factory creates.
|
| 897 |
+
*
|
| 898 |
+
* The EP factory implementation should consider the set of devices (e.g., multi-adapter or multi-GPU scenarios) when
|
| 899 |
+
* evaluating compatibility and set `model_compatibility` accordingly.
|
| 900 |
+
*
|
| 901 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 902 |
+
* \param[in] devices Array of OrtHardwareDevice pointers that the EP would run on. All must map to this EP.
|
| 903 |
+
* \param[in] num_devices Number of entries in `devices`.
|
| 904 |
+
* \param[in] compatibility_info The compatibility information string produced when the model was compiled.
|
| 905 |
+
* \param[out] model_compatibility OrtCompiledModelCompatibility value describing the compatibility of the model with the EP.
|
| 906 |
+
*
|
| 907 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 908 |
+
*
|
| 909 |
+
* \since Version 1.23.
|
| 910 |
+
*/
|
| 911 |
+
ORT_API2_STATUS(ValidateCompiledModelCompatibilityInfo, _In_ OrtEpFactory* this_ptr,
|
| 912 |
+
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
|
| 913 |
+
_In_ size_t num_devices,
|
| 914 |
+
_In_ const char* compatibility_info,
|
| 915 |
+
_Out_ OrtCompiledModelCompatibility* model_compatibility);
|
| 916 |
+
|
| 917 |
+
/** \brief Create an OrtAllocator that can be shared across sessions for the given OrtMemoryInfo.
|
| 918 |
+
*
|
| 919 |
+
* The factory that creates the EP is responsible for providing the allocators required by the EP.
|
| 920 |
+
* The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using EpDevice_AddAllocatorInfo.
|
| 921 |
+
*
|
| 922 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 923 |
+
* \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr.
|
| 924 |
+
* \param[in] allocator_options Optional key-value pairs for allocator options, can be nullptr.
|
| 925 |
+
* \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used.
|
| 926 |
+
*
|
| 927 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 928 |
+
*
|
| 929 |
+
* \since Version 1.23.
|
| 930 |
+
*/
|
| 931 |
+
ORT_API2_STATUS(CreateAllocator, _In_ OrtEpFactory* this_ptr,
|
| 932 |
+
_In_ const OrtMemoryInfo* memory_info,
|
| 933 |
+
_In_opt_ const OrtKeyValuePairs* allocator_options,
|
| 934 |
+
_Outptr_result_maybenull_ OrtAllocator** allocator);
|
| 935 |
+
|
| 936 |
+
/** \brief Release an OrtAllocator created by the factory.
|
| 937 |
+
*
|
| 938 |
+
* \since Version 1.23.
|
| 939 |
+
*/
|
| 940 |
+
ORT_API_T(void, ReleaseAllocator, _In_ OrtEpFactory* this_ptr, _In_ OrtAllocator* allocator);
|
| 941 |
+
|
| 942 |
+
/** \brief Create an OrtDataTransferImpl instance for the factory.
|
| 943 |
+
*
|
| 944 |
+
* This is used to create an IDataTransfer implementation that can be used to copy data between devices
|
| 945 |
+
* that the execution provider supports.
|
| 946 |
+
*
|
| 947 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 948 |
+
* \param[out] data_transfer The created OrtDataTransferImpl instance. Set to nullptr if not required.
|
| 949 |
+
*
|
| 950 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 951 |
+
*
|
| 952 |
+
* \since Version 1.23.
|
| 953 |
+
*/
|
| 954 |
+
ORT_API2_STATUS(CreateDataTransfer, _In_ OrtEpFactory* this_ptr,
|
| 955 |
+
_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer);
|
| 956 |
+
|
| 957 |
+
/** \brief Check if execution providers created by the factory are stream aware.
|
| 958 |
+
*
|
| 959 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 960 |
+
* \return True if the factory creates execution providers that are stream aware and it implements CreateSyncStreamForDevice.
|
| 961 |
+
*
|
| 962 |
+
* \since Version 1.23.
|
| 963 |
+
*/
|
| 964 |
+
ORT_API_T(bool, IsStreamAware, _In_ const OrtEpFactory* this_ptr);
|
| 965 |
+
|
| 966 |
+
/** \brief Create a synchronization stream for the given memory device.
|
| 967 |
+
*
|
| 968 |
+
* This is used to create a synchronization stream for the memory device that can be used for operations outside of
|
| 969 |
+
* a session.
|
| 970 |
+
*
|
| 971 |
+
* \param[in] this_ptr The OrtEpFactory instance.
|
| 972 |
+
* \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for.
|
| 973 |
+
* \param[in] stream_options Options for stream creation. May be nullptr.
|
| 974 |
+
* \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not stream aware.
|
| 975 |
+
*
|
| 976 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
| 977 |
+
*
|
| 978 |
+
* \since Version 1.23.
|
| 979 |
+
*/
|
| 980 |
+
ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEpFactory* this_ptr,
|
| 981 |
+
_In_ const OrtMemoryDevice* memory_device,
|
| 982 |
+
_In_opt_ const OrtKeyValuePairs* stream_options,
|
| 983 |
+
_Outptr_ OrtSyncStreamImpl** stream);
|
| 984 |
+
};
|
| 985 |
+
|
| 986 |
+
#ifdef __cplusplus
|
| 987 |
+
}
|
| 988 |
+
#endif
|
v1.23.1/headers/onnxruntime_ep_device_ep_metadata_keys.h
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
// This file contains well-known keys for OrtEpDevice EP metadata entries.
|
| 7 |
+
// It does NOT specify all available metadata keys.
|
| 8 |
+
|
| 9 |
+
// Key for the execution provider version string. This should be available for all plugin EPs.
|
| 10 |
+
static const char* const kOrtEpDevice_EpMetadataKey_Version = "version";
|
| 11 |
+
|
| 12 |
+
// Prefix for execution provider compatibility information stored in model metadata.
|
| 13 |
+
// Used when generating EP context models to store compatibility strings for each EP.
|
| 14 |
+
// Full key format: "ep_compatibility_info.<EP_TYPE>"
|
| 15 |
+
static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info.";
|
| 16 |
+
|
| 17 |
+
// Key for the execution provider library path (for dynamically loaded EPs)
|
| 18 |
+
static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path";
|
v1.23.1/headers/onnxruntime_float16.h
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
#include <stdint.h>
|
| 7 |
+
#include <cmath>
|
| 8 |
+
#include <cstring>
|
| 9 |
+
#include <limits>
|
| 10 |
+
|
| 11 |
+
namespace onnxruntime_float16 {
|
| 12 |
+
|
| 13 |
+
namespace detail {
|
| 14 |
+
|
| 15 |
+
enum class endian {
|
| 16 |
+
#if defined(_WIN32)
|
| 17 |
+
little = 0,
|
| 18 |
+
big = 1,
|
| 19 |
+
native = little,
|
| 20 |
+
#elif defined(__GNUC__) || defined(__clang__)
|
| 21 |
+
little = __ORDER_LITTLE_ENDIAN__,
|
| 22 |
+
big = __ORDER_BIG_ENDIAN__,
|
| 23 |
+
native = __BYTE_ORDER__,
|
| 24 |
+
#else
|
| 25 |
+
#error onnxruntime_float16::detail::endian is not implemented in this environment.
|
| 26 |
+
#endif
|
| 27 |
+
};
|
| 28 |
+
|
| 29 |
+
static_assert(
|
| 30 |
+
endian::native == endian::little || endian::native == endian::big,
|
| 31 |
+
"Only little-endian or big-endian native byte orders are supported.");
|
| 32 |
+
|
| 33 |
+
} // namespace detail
|
| 34 |
+
|
| 35 |
+
/// <summary>
|
| 36 |
+
/// Shared implementation between public and internal classes. CRTP pattern.
|
| 37 |
+
/// </summary>
|
| 38 |
+
template <class Derived>
|
| 39 |
+
struct Float16Impl {
|
| 40 |
+
protected:
|
| 41 |
+
/// <summary>
|
| 42 |
+
/// Converts from float to uint16_t float16 representation
|
| 43 |
+
/// </summary>
|
| 44 |
+
/// <param name="v"></param>
|
| 45 |
+
/// <returns></returns>
|
| 46 |
+
constexpr static uint16_t ToUint16Impl(float v) noexcept;
|
| 47 |
+
|
| 48 |
+
/// <summary>
|
| 49 |
+
/// Converts float16 to float
|
| 50 |
+
/// </summary>
|
| 51 |
+
/// <returns>float representation of float16 value</returns>
|
| 52 |
+
float ToFloatImpl() const noexcept;
|
| 53 |
+
|
| 54 |
+
/// <summary>
|
| 55 |
+
/// Creates an instance that represents absolute value.
|
| 56 |
+
/// </summary>
|
| 57 |
+
/// <returns>Absolute value</returns>
|
| 58 |
+
uint16_t AbsImpl() const noexcept {
|
| 59 |
+
return static_cast<uint16_t>(val & ~kSignMask);
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/// <summary>
|
| 63 |
+
/// Creates a new instance with the sign flipped.
|
| 64 |
+
/// </summary>
|
| 65 |
+
/// <returns>Flipped sign instance</returns>
|
| 66 |
+
uint16_t NegateImpl() const noexcept {
|
| 67 |
+
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
public:
|
| 71 |
+
// uint16_t special values
|
| 72 |
+
static constexpr uint16_t kSignMask = 0x8000U;
|
| 73 |
+
static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
|
| 74 |
+
static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
|
| 75 |
+
static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
|
| 76 |
+
static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
|
| 77 |
+
static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
|
| 78 |
+
static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
|
| 79 |
+
static constexpr uint16_t kOneBits = 0x3C00U;
|
| 80 |
+
static constexpr uint16_t kMinusOneBits = 0xBC00U;
|
| 81 |
+
|
| 82 |
+
uint16_t val{0};
|
| 83 |
+
|
| 84 |
+
Float16Impl() = default;
|
| 85 |
+
|
| 86 |
+
/// <summary>
|
| 87 |
+
/// Checks if the value is negative
|
| 88 |
+
/// </summary>
|
| 89 |
+
/// <returns>true if negative</returns>
|
| 90 |
+
bool IsNegative() const noexcept {
|
| 91 |
+
return static_cast<int16_t>(val) < 0;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
/// <summary>
|
| 95 |
+
/// Tests if the value is NaN
|
| 96 |
+
/// </summary>
|
| 97 |
+
/// <returns>true if NaN</returns>
|
| 98 |
+
bool IsNaN() const noexcept {
|
| 99 |
+
return AbsImpl() > kPositiveInfinityBits;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
/// <summary>
|
| 103 |
+
/// Tests if the value is finite
|
| 104 |
+
/// </summary>
|
| 105 |
+
/// <returns>true if finite</returns>
|
| 106 |
+
bool IsFinite() const noexcept {
|
| 107 |
+
return AbsImpl() < kPositiveInfinityBits;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
/// <summary>
|
| 111 |
+
/// Tests if the value represents positive infinity.
|
| 112 |
+
/// </summary>
|
| 113 |
+
/// <returns>true if positive infinity</returns>
|
| 114 |
+
bool IsPositiveInfinity() const noexcept {
|
| 115 |
+
return val == kPositiveInfinityBits;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
/// <summary>
|
| 119 |
+
/// Tests if the value represents negative infinity
|
| 120 |
+
/// </summary>
|
| 121 |
+
/// <returns>true if negative infinity</returns>
|
| 122 |
+
bool IsNegativeInfinity() const noexcept {
|
| 123 |
+
return val == kNegativeInfinityBits;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
/// <summary>
|
| 127 |
+
/// Tests if the value is either positive or negative infinity.
|
| 128 |
+
/// </summary>
|
| 129 |
+
/// <returns>True if absolute value is infinity</returns>
|
| 130 |
+
bool IsInfinity() const noexcept {
|
| 131 |
+
return AbsImpl() == kPositiveInfinityBits;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
/// <summary>
|
| 135 |
+
/// Tests if the value is NaN or zero. Useful for comparisons.
|
| 136 |
+
/// </summary>
|
| 137 |
+
/// <returns>True if NaN or zero.</returns>
|
| 138 |
+
bool IsNaNOrZero() const noexcept {
|
| 139 |
+
auto abs = AbsImpl();
|
| 140 |
+
return (abs == 0 || abs > kPositiveInfinityBits);
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
/// <summary>
|
| 144 |
+
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
|
| 145 |
+
/// </summary>
|
| 146 |
+
/// <returns>True if so</returns>
|
| 147 |
+
bool IsNormal() const noexcept {
|
| 148 |
+
auto abs = AbsImpl();
|
| 149 |
+
return (abs < kPositiveInfinityBits) // is finite
|
| 150 |
+
&& (abs != 0) // is not zero
|
| 151 |
+
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/// <summary>
|
| 155 |
+
/// Tests if the value is subnormal (denormal).
|
| 156 |
+
/// </summary>
|
| 157 |
+
/// <returns>True if so</returns>
|
| 158 |
+
bool IsSubnormal() const noexcept {
|
| 159 |
+
auto abs = AbsImpl();
|
| 160 |
+
return (abs < kPositiveInfinityBits) // is finite
|
| 161 |
+
&& (abs != 0) // is not zero
|
| 162 |
+
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
/// <summary>
|
| 166 |
+
/// Creates an instance that represents absolute value.
|
| 167 |
+
/// </summary>
|
| 168 |
+
/// <returns>Absolute value</returns>
|
| 169 |
+
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
|
| 170 |
+
|
| 171 |
+
/// <summary>
|
| 172 |
+
/// Creates a new instance with the sign flipped.
|
| 173 |
+
/// </summary>
|
| 174 |
+
/// <returns>Flipped sign instance</returns>
|
| 175 |
+
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
|
| 176 |
+
|
| 177 |
+
/// <summary>
|
| 178 |
+
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
|
| 179 |
+
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
|
| 180 |
+
/// and therefore equivalent, if the resulting value is still zero.
|
| 181 |
+
/// </summary>
|
| 182 |
+
/// <param name="lhs">first value</param>
|
| 183 |
+
/// <param name="rhs">second value</param>
|
| 184 |
+
/// <returns>True if both arguments represent zero</returns>
|
| 185 |
+
static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
|
| 186 |
+
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
bool operator==(const Float16Impl& rhs) const noexcept {
|
| 190 |
+
if (IsNaN() || rhs.IsNaN()) {
|
| 191 |
+
// IEEE defines that NaN is not equal to anything, including itself.
|
| 192 |
+
return false;
|
| 193 |
+
}
|
| 194 |
+
return val == rhs.val;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
|
| 198 |
+
|
| 199 |
+
bool operator<(const Float16Impl& rhs) const noexcept {
|
| 200 |
+
if (IsNaN() || rhs.IsNaN()) {
|
| 201 |
+
// IEEE defines that NaN is unordered with respect to everything, including itself.
|
| 202 |
+
return false;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
const bool left_is_negative = IsNegative();
|
| 206 |
+
if (left_is_negative != rhs.IsNegative()) {
|
| 207 |
+
// When the signs of left and right differ, we know that left is less than right if it is
|
| 208 |
+
// the negative value. The exception to this is if both values are zero, in which case IEEE
|
| 209 |
+
// says they should be equal, even if the signs differ.
|
| 210 |
+
return left_is_negative && !AreZero(*this, rhs);
|
| 211 |
+
}
|
| 212 |
+
return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
|
| 213 |
+
}
|
| 214 |
+
};
|
| 215 |
+
|
| 216 |
+
// The following Float16_t conversions are based on the code from
|
| 217 |
+
// Eigen library.
|
| 218 |
+
|
| 219 |
+
// The conversion routines are Copyright (c) Fabian Giesen, 2016.
|
| 220 |
+
// The original license follows:
|
| 221 |
+
//
|
| 222 |
+
// Copyright (c) Fabian Giesen, 2016
|
| 223 |
+
// All rights reserved.
|
| 224 |
+
// Redistribution and use in source and binary forms, with or without
|
| 225 |
+
// modification, are permitted.
|
| 226 |
+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 227 |
+
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 228 |
+
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 229 |
+
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 230 |
+
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 231 |
+
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 232 |
+
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 233 |
+
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 234 |
+
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 235 |
+
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 236 |
+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 237 |
+
|
| 238 |
+
namespace detail {
|
| 239 |
+
union float32_bits {
|
| 240 |
+
unsigned int u;
|
| 241 |
+
float f;
|
| 242 |
+
};
|
| 243 |
+
} // namespace detail
|
| 244 |
+
|
| 245 |
+
template <class Derived>
|
| 246 |
+
inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
|
| 247 |
+
detail::float32_bits f{};
|
| 248 |
+
f.f = v;
|
| 249 |
+
|
| 250 |
+
constexpr detail::float32_bits f32infty = {255 << 23};
|
| 251 |
+
constexpr detail::float32_bits f16max = {(127 + 16) << 23};
|
| 252 |
+
constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
|
| 253 |
+
constexpr unsigned int sign_mask = 0x80000000u;
|
| 254 |
+
uint16_t val = static_cast<uint16_t>(0x0u);
|
| 255 |
+
|
| 256 |
+
unsigned int sign = f.u & sign_mask;
|
| 257 |
+
f.u ^= sign;
|
| 258 |
+
|
| 259 |
+
// NOTE all the integer compares in this function can be safely
|
| 260 |
+
// compiled into signed compares since all operands are below
|
| 261 |
+
// 0x80000000. Important if you want fast straight SSE2 code
|
| 262 |
+
// (since there's no unsigned PCMPGTD).
|
| 263 |
+
|
| 264 |
+
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
|
| 265 |
+
val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
|
| 266 |
+
} else { // (De)normalized number or zero
|
| 267 |
+
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
|
| 268 |
+
// use a magic value to align our 10 mantissa bits at the bottom of
|
| 269 |
+
// the float. as long as FP addition is round-to-nearest-even this
|
| 270 |
+
// just works.
|
| 271 |
+
f.f += denorm_magic.f;
|
| 272 |
+
|
| 273 |
+
// and one integer subtract of the bias later, we have our final float!
|
| 274 |
+
val = static_cast<uint16_t>(f.u - denorm_magic.u);
|
| 275 |
+
} else {
|
| 276 |
+
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
|
| 277 |
+
|
| 278 |
+
// update exponent, rounding bias part 1
|
| 279 |
+
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
|
| 280 |
+
// without arithmetic overflow.
|
| 281 |
+
f.u += 0xc8000fffU;
|
| 282 |
+
// rounding bias part 2
|
| 283 |
+
f.u += mant_odd;
|
| 284 |
+
// take the bits!
|
| 285 |
+
val = static_cast<uint16_t>(f.u >> 13);
|
| 286 |
+
}
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
val |= static_cast<uint16_t>(sign >> 16);
|
| 290 |
+
return val;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
template <class Derived>
|
| 294 |
+
inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
|
| 295 |
+
constexpr detail::float32_bits magic = {113 << 23};
|
| 296 |
+
constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
| 297 |
+
detail::float32_bits o{};
|
| 298 |
+
|
| 299 |
+
o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
|
| 300 |
+
unsigned int exp = shifted_exp & o.u; // just the exponent
|
| 301 |
+
o.u += (127 - 15) << 23; // exponent adjust
|
| 302 |
+
|
| 303 |
+
// handle exponent special cases
|
| 304 |
+
if (exp == shifted_exp) { // Inf/NaN?
|
| 305 |
+
o.u += (128 - 16) << 23; // extra exp adjust
|
| 306 |
+
} else if (exp == 0) { // Zero/Denormal?
|
| 307 |
+
o.u += 1 << 23; // extra exp adjust
|
| 308 |
+
o.f -= magic.f; // re-normalize
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
// Attempt to workaround the Internal Compiler Error on ARM64
|
| 312 |
+
// for bitwise | operator, including std::bitset
|
| 313 |
+
#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
|
| 314 |
+
if (IsNegative()) {
|
| 315 |
+
return -o.f;
|
| 316 |
+
}
|
| 317 |
+
#else
|
| 318 |
+
// original code:
|
| 319 |
+
o.u |= (val & 0x8000U) << 16U; // sign bit
|
| 320 |
+
#endif
|
| 321 |
+
return o.f;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
/// Shared implementation between public and internal classes. CRTP pattern.
|
| 325 |
+
template <class Derived>
|
| 326 |
+
struct BFloat16Impl {
|
| 327 |
+
protected:
|
| 328 |
+
/// <summary>
|
| 329 |
+
/// Converts from float to uint16_t float16 representation
|
| 330 |
+
/// </summary>
|
| 331 |
+
/// <param name="v"></param>
|
| 332 |
+
/// <returns></returns>
|
| 333 |
+
static uint16_t ToUint16Impl(float v) noexcept;
|
| 334 |
+
|
| 335 |
+
/// <summary>
|
| 336 |
+
/// Converts bfloat16 to float
|
| 337 |
+
/// </summary>
|
| 338 |
+
/// <returns>float representation of bfloat16 value</returns>
|
| 339 |
+
float ToFloatImpl() const noexcept;
|
| 340 |
+
|
| 341 |
+
/// <summary>
|
| 342 |
+
/// Creates an instance that represents absolute value.
|
| 343 |
+
/// </summary>
|
| 344 |
+
/// <returns>Absolute value</returns>
|
| 345 |
+
uint16_t AbsImpl() const noexcept {
|
| 346 |
+
return static_cast<uint16_t>(val & ~kSignMask);
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
/// <summary>
|
| 350 |
+
/// Creates a new instance with the sign flipped.
|
| 351 |
+
/// </summary>
|
| 352 |
+
/// <returns>Flipped sign instance</returns>
|
| 353 |
+
uint16_t NegateImpl() const noexcept {
|
| 354 |
+
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
public:
|
| 358 |
+
// uint16_t special values
|
| 359 |
+
static constexpr uint16_t kSignMask = 0x8000U;
|
| 360 |
+
static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
|
| 361 |
+
static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
|
| 362 |
+
static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
|
| 363 |
+
static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
|
| 364 |
+
static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
|
| 365 |
+
static constexpr uint16_t kMaxValueBits = 0x7F7FU;
|
| 366 |
+
static constexpr uint16_t kRoundToNearest = 0x7FFFU;
|
| 367 |
+
static constexpr uint16_t kOneBits = 0x3F80U;
|
| 368 |
+
static constexpr uint16_t kMinusOneBits = 0xBF80U;
|
| 369 |
+
|
| 370 |
+
uint16_t val{0};
|
| 371 |
+
|
| 372 |
+
BFloat16Impl() = default;
|
| 373 |
+
|
| 374 |
+
/// <summary>
|
| 375 |
+
/// Checks if the value is negative
|
| 376 |
+
/// </summary>
|
| 377 |
+
/// <returns>true if negative</returns>
|
| 378 |
+
bool IsNegative() const noexcept {
|
| 379 |
+
return static_cast<int16_t>(val) < 0;
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
/// <summary>
|
| 383 |
+
/// Tests if the value is NaN
|
| 384 |
+
/// </summary>
|
| 385 |
+
/// <returns>true if NaN</returns>
|
| 386 |
+
bool IsNaN() const noexcept {
|
| 387 |
+
return AbsImpl() > kPositiveInfinityBits;
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
/// <summary>
|
| 391 |
+
/// Tests if the value is finite
|
| 392 |
+
/// </summary>
|
| 393 |
+
/// <returns>true if finite</returns>
|
| 394 |
+
bool IsFinite() const noexcept {
|
| 395 |
+
return AbsImpl() < kPositiveInfinityBits;
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
/// <summary>
|
| 399 |
+
/// Tests if the value represents positive infinity.
|
| 400 |
+
/// </summary>
|
| 401 |
+
/// <returns>true if positive infinity</returns>
|
| 402 |
+
bool IsPositiveInfinity() const noexcept {
|
| 403 |
+
return val == kPositiveInfinityBits;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
/// <summary>
|
| 407 |
+
/// Tests if the value represents negative infinity
|
| 408 |
+
/// </summary>
|
| 409 |
+
/// <returns>true if negative infinity</returns>
|
| 410 |
+
bool IsNegativeInfinity() const noexcept {
|
| 411 |
+
return val == kNegativeInfinityBits;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
/// <summary>
|
| 415 |
+
/// Tests if the value is either positive or negative infinity.
|
| 416 |
+
/// </summary>
|
| 417 |
+
/// <returns>True if absolute value is infinity</returns>
|
| 418 |
+
bool IsInfinity() const noexcept {
|
| 419 |
+
return AbsImpl() == kPositiveInfinityBits;
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
/// <summary>
|
| 423 |
+
/// Tests if the value is NaN or zero. Useful for comparisons.
|
| 424 |
+
/// </summary>
|
| 425 |
+
/// <returns>True if NaN or zero.</returns>
|
| 426 |
+
bool IsNaNOrZero() const noexcept {
|
| 427 |
+
auto abs = AbsImpl();
|
| 428 |
+
return (abs == 0 || abs > kPositiveInfinityBits);
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
/// <summary>
|
| 432 |
+
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
|
| 433 |
+
/// </summary>
|
| 434 |
+
/// <returns>True if so</returns>
|
| 435 |
+
bool IsNormal() const noexcept {
|
| 436 |
+
auto abs = AbsImpl();
|
| 437 |
+
return (abs < kPositiveInfinityBits) // is finite
|
| 438 |
+
&& (abs != 0) // is not zero
|
| 439 |
+
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
/// <summary>
|
| 443 |
+
/// Tests if the value is subnormal (denormal).
|
| 444 |
+
/// </summary>
|
| 445 |
+
/// <returns>True if so</returns>
|
| 446 |
+
bool IsSubnormal() const noexcept {
|
| 447 |
+
auto abs = AbsImpl();
|
| 448 |
+
return (abs < kPositiveInfinityBits) // is finite
|
| 449 |
+
&& (abs != 0) // is not zero
|
| 450 |
+
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
/// <summary>
|
| 454 |
+
/// Creates an instance that represents absolute value.
|
| 455 |
+
/// </summary>
|
| 456 |
+
/// <returns>Absolute value</returns>
|
| 457 |
+
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
|
| 458 |
+
|
| 459 |
+
/// <summary>
|
| 460 |
+
/// Creates a new instance with the sign flipped.
|
| 461 |
+
/// </summary>
|
| 462 |
+
/// <returns>Flipped sign instance</returns>
|
| 463 |
+
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
|
| 464 |
+
|
| 465 |
+
/// <summary>
|
| 466 |
+
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
|
| 467 |
+
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
|
| 468 |
+
/// and therefore equivalent, if the resulting value is still zero.
|
| 469 |
+
/// </summary>
|
| 470 |
+
/// <param name="lhs">first value</param>
|
| 471 |
+
/// <param name="rhs">second value</param>
|
| 472 |
+
/// <returns>True if both arguments represent zero</returns>
|
| 473 |
+
static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
|
| 474 |
+
// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
|
| 475 |
+
// for two values by or'ing the private bits together and stripping the sign. They are both zero,
|
| 476 |
+
// and therefore equivalent, if the resulting value is still zero.
|
| 477 |
+
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
|
| 478 |
+
}
|
| 479 |
+
};
|
| 480 |
+
|
| 481 |
+
template <class Derived>
|
| 482 |
+
inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
|
| 483 |
+
uint16_t result;
|
| 484 |
+
if (std::isnan(v)) {
|
| 485 |
+
result = kPositiveQNaNBits;
|
| 486 |
+
} else {
|
| 487 |
+
auto get_msb_half = [](float fl) {
|
| 488 |
+
uint16_t result;
|
| 489 |
+
#ifdef __cpp_if_constexpr
|
| 490 |
+
if constexpr (detail::endian::native == detail::endian::little) {
|
| 491 |
+
#else
|
| 492 |
+
if (detail::endian::native == detail::endian::little) {
|
| 493 |
+
#endif
|
| 494 |
+
std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
|
| 495 |
+
} else {
|
| 496 |
+
std::memcpy(&result, &fl, sizeof(uint16_t));
|
| 497 |
+
}
|
| 498 |
+
return result;
|
| 499 |
+
};
|
| 500 |
+
|
| 501 |
+
uint16_t upper_bits = get_msb_half(v);
|
| 502 |
+
union {
|
| 503 |
+
uint32_t U32;
|
| 504 |
+
float F32;
|
| 505 |
+
};
|
| 506 |
+
F32 = v;
|
| 507 |
+
U32 += (upper_bits & 1) + kRoundToNearest;
|
| 508 |
+
result = get_msb_half(F32);
|
| 509 |
+
}
|
| 510 |
+
return result;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
template <class Derived>
|
| 514 |
+
inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
|
| 515 |
+
if (IsNaN()) {
|
| 516 |
+
return std::numeric_limits<float>::quiet_NaN();
|
| 517 |
+
}
|
| 518 |
+
float result;
|
| 519 |
+
char* const first = reinterpret_cast<char*>(&result);
|
| 520 |
+
char* const second = first + sizeof(uint16_t);
|
| 521 |
+
#ifdef __cpp_if_constexpr
|
| 522 |
+
if constexpr (detail::endian::native == detail::endian::little) {
|
| 523 |
+
#else
|
| 524 |
+
if (detail::endian::native == detail::endian::little) {
|
| 525 |
+
#endif
|
| 526 |
+
std::memset(first, 0, sizeof(uint16_t));
|
| 527 |
+
std::memcpy(second, &val, sizeof(uint16_t));
|
| 528 |
+
} else {
|
| 529 |
+
std::memcpy(first, &val, sizeof(uint16_t));
|
| 530 |
+
std::memset(second, 0, sizeof(uint16_t));
|
| 531 |
+
}
|
| 532 |
+
return result;
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
} // namespace onnxruntime_float16
|
v1.23.1/headers/onnxruntime_lite_custom_op.h
ADDED
|
@@ -0,0 +1,1119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
// Summary
|
| 5 |
+
// The header has APIs to save custom op authors the trouble of defining schemas,
|
| 6 |
+
// which will be inferred by functions' signature, as long as their argument list has types supported here.
|
| 7 |
+
// Input could be:
|
| 8 |
+
// 1. Tensor of onnx data types.
|
| 9 |
+
// 2. Span of onnx data types.
|
| 10 |
+
// 3. Scalar of onnx data types.
|
| 11 |
+
// A input could be optional if indicated as std::optional<...>.
|
| 12 |
+
// For an output, it must be a tensor of onnx data types.
|
| 13 |
+
// Further, the header also has utility for a simple custom struct, where resources could be kept, to be registered as a custom op.
|
| 14 |
+
// For concrete examples, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
|
| 15 |
+
// Note - all APIs in this header are ABI.
|
| 16 |
+
|
| 17 |
+
#pragma once
|
| 18 |
+
#include "onnxruntime_cxx_api.h"
|
| 19 |
+
#include <optional>
|
| 20 |
+
#include <numeric>
|
| 21 |
+
#include <functional>
|
| 22 |
+
#include <unordered_set>
|
| 23 |
+
|
| 24 |
+
namespace Ort {
|
| 25 |
+
namespace Custom {
|
| 26 |
+
|
| 27 |
+
class ArgBase {
|
| 28 |
+
public:
|
| 29 |
+
ArgBase(OrtKernelContext* ctx,
|
| 30 |
+
size_t indice,
|
| 31 |
+
bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
|
| 32 |
+
virtual ~ArgBase() {};
|
| 33 |
+
|
| 34 |
+
protected:
|
| 35 |
+
struct KernelContext ctx_;
|
| 36 |
+
size_t indice_;
|
| 37 |
+
bool is_input_;
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
using ArgPtr = std::unique_ptr<Custom::ArgBase>;
|
| 41 |
+
using ArgPtrs = std::vector<ArgPtr>;
|
| 42 |
+
|
| 43 |
+
class TensorBase : public ArgBase {
|
| 44 |
+
public:
|
| 45 |
+
TensorBase(OrtKernelContext* ctx,
|
| 46 |
+
size_t indice,
|
| 47 |
+
bool is_input) : ArgBase(ctx, indice, is_input) {}
|
| 48 |
+
|
| 49 |
+
operator bool() const {
|
| 50 |
+
return shape_.has_value();
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
const std::vector<int64_t>& Shape() const {
|
| 54 |
+
if (!shape_.has_value()) {
|
| 55 |
+
ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
| 56 |
+
}
|
| 57 |
+
return shape_.value();
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
ONNXTensorElementDataType Type() const {
|
| 61 |
+
return type_;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
int64_t NumberOfElement() const {
|
| 65 |
+
if (shape_.has_value()) {
|
| 66 |
+
return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
|
| 67 |
+
} else {
|
| 68 |
+
return 0;
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
std::string Shape2Str() const {
|
| 73 |
+
if (shape_.has_value()) {
|
| 74 |
+
std::string shape_str;
|
| 75 |
+
for (const auto& dim : *shape_) {
|
| 76 |
+
shape_str.append(std::to_string(dim));
|
| 77 |
+
shape_str.append(", ");
|
| 78 |
+
}
|
| 79 |
+
return shape_str;
|
| 80 |
+
} else {
|
| 81 |
+
return "empty";
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
bool IsCpuTensor() const {
|
| 86 |
+
return strcmp("Cpu", mem_type_) == 0;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
virtual const void* DataRaw() const = 0;
|
| 90 |
+
virtual size_t SizeInBytes() const = 0;
|
| 91 |
+
|
| 92 |
+
protected:
|
| 93 |
+
std::optional<std::vector<int64_t>> shape_;
|
| 94 |
+
ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
|
| 95 |
+
const char* mem_type_ = "Cpu";
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
template <typename T>
|
| 99 |
+
struct Span {
|
| 100 |
+
const T* data_ = {};
|
| 101 |
+
size_t size_ = {};
|
| 102 |
+
void Assign(const T* data, size_t size) {
|
| 103 |
+
data_ = data;
|
| 104 |
+
size_ = size;
|
| 105 |
+
}
|
| 106 |
+
size_t size() const { return size_; }
|
| 107 |
+
T operator[](size_t indice) const {
|
| 108 |
+
return data_[indice];
|
| 109 |
+
}
|
| 110 |
+
const T* data() const { return data_; }
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
template <typename T>
|
| 114 |
+
class Tensor : public TensorBase {
|
| 115 |
+
public:
|
| 116 |
+
using TT = typename std::remove_reference<T>::type;
|
| 117 |
+
Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
|
| 118 |
+
if (is_input_) {
|
| 119 |
+
if (indice >= ctx_.GetInputCount()) {
|
| 120 |
+
ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
|
| 121 |
+
}
|
| 122 |
+
const_value_ = ctx_.GetInput(indice);
|
| 123 |
+
auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
|
| 124 |
+
shape_ = type_shape_info.GetShape();
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
const TT* Data() const {
|
| 128 |
+
return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
|
| 129 |
+
}
|
| 130 |
+
TT* Allocate(const std::vector<int64_t>& shape) {
|
| 131 |
+
shape_ = shape;
|
| 132 |
+
if (!data_) {
|
| 133 |
+
shape_ = shape;
|
| 134 |
+
data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
|
| 135 |
+
}
|
| 136 |
+
return data_;
|
| 137 |
+
}
|
| 138 |
+
static TT GetT() { return (TT)0; }
|
| 139 |
+
const Span<T>& AsSpan() {
|
| 140 |
+
if (!shape_.has_value() || shape_->size() != 1) {
|
| 141 |
+
ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
|
| 142 |
+
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
| 143 |
+
}
|
| 144 |
+
span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
|
| 145 |
+
return span_;
|
| 146 |
+
}
|
| 147 |
+
const T& AsScalar() {
|
| 148 |
+
if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
|
| 149 |
+
ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
|
| 150 |
+
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
| 151 |
+
}
|
| 152 |
+
return *Data();
|
| 153 |
+
}
|
| 154 |
+
const void* DataRaw() const override {
|
| 155 |
+
return reinterpret_cast<const void*>(Data());
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
size_t SizeInBytes() const override {
|
| 159 |
+
return sizeof(TT) * static_cast<size_t>(NumberOfElement());
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
private:
|
| 163 |
+
ConstValue const_value_; // for input
|
| 164 |
+
TT* data_{}; // for output
|
| 165 |
+
Span<T> span_;
|
| 166 |
+
};
|
| 167 |
+
|
| 168 |
+
template <>
|
| 169 |
+
class Tensor<std::string> : public TensorBase {
|
| 170 |
+
public:
|
| 171 |
+
using strings = std::vector<std::string>;
|
| 172 |
+
|
| 173 |
+
Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
|
| 174 |
+
if (is_input_) {
|
| 175 |
+
if (indice >= ctx_.GetInputCount()) {
|
| 176 |
+
ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
|
| 177 |
+
}
|
| 178 |
+
auto const_value = ctx_.GetInput(indice);
|
| 179 |
+
auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
|
| 180 |
+
shape_ = type_shape_info.GetShape();
|
| 181 |
+
auto num_chars = const_value.GetStringTensorDataLength();
|
| 182 |
+
// note - there will be copy ...
|
| 183 |
+
auto num_strings = static_cast<size_t>(NumberOfElement());
|
| 184 |
+
if (num_strings) {
|
| 185 |
+
std::vector<char> chars(num_chars + 1, '\0');
|
| 186 |
+
std::vector<size_t> offsets(num_strings);
|
| 187 |
+
const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
|
| 188 |
+
auto upper_bound = num_strings - 1;
|
| 189 |
+
input_strings_.resize(num_strings);
|
| 190 |
+
for (size_t i = upper_bound;; --i) {
|
| 191 |
+
if (i < upper_bound) {
|
| 192 |
+
chars[offsets[i + 1]] = '\0';
|
| 193 |
+
}
|
| 194 |
+
input_strings_[i] = chars.data() + offsets[i];
|
| 195 |
+
if (0 == i) {
|
| 196 |
+
break;
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
const strings& Data() const {
|
| 203 |
+
return input_strings_;
|
| 204 |
+
}
|
| 205 |
+
const void* DataRaw() const override {
|
| 206 |
+
if (input_strings_.size() != 1) {
|
| 207 |
+
ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
| 208 |
+
}
|
| 209 |
+
return reinterpret_cast<const void*>(input_strings_[0].c_str());
|
| 210 |
+
}
|
| 211 |
+
size_t SizeInBytes() const override {
|
| 212 |
+
if (input_strings_.size() != 1) {
|
| 213 |
+
ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
| 214 |
+
}
|
| 215 |
+
return input_strings_[0].size();
|
| 216 |
+
}
|
| 217 |
+
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
|
| 218 |
+
shape_ = dims;
|
| 219 |
+
std::vector<const char*> raw;
|
| 220 |
+
for (const auto& s : ss) {
|
| 221 |
+
raw.push_back(s.data());
|
| 222 |
+
}
|
| 223 |
+
auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
|
| 224 |
+
// note - there will be copy ...
|
| 225 |
+
output.FillStringTensor(raw.data(), raw.size());
|
| 226 |
+
}
|
| 227 |
+
const Span<std::string>& AsSpan() {
|
| 228 |
+
ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
| 229 |
+
}
|
| 230 |
+
const std::string& AsScalar() {
|
| 231 |
+
if (input_strings_.size() != 1) {
|
| 232 |
+
ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
|
| 233 |
+
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
| 234 |
+
}
|
| 235 |
+
return input_strings_[0];
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
private:
|
| 239 |
+
std::vector<std::string> input_strings_; // for input
|
| 240 |
+
};
|
| 241 |
+
|
| 242 |
+
template <>
|
| 243 |
+
class Tensor<std::string_view> : public TensorBase {
|
| 244 |
+
public:
|
| 245 |
+
using strings = std::vector<std::string>;
|
| 246 |
+
using string_views = std::vector<std::string_view>;
|
| 247 |
+
|
| 248 |
+
Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
|
| 249 |
+
if (is_input_) {
|
| 250 |
+
if (indice >= ctx_.GetInputCount()) {
|
| 251 |
+
ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
|
| 252 |
+
}
|
| 253 |
+
auto const_value = ctx_.GetInput(indice);
|
| 254 |
+
auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
|
| 255 |
+
shape_ = type_shape_info.GetShape();
|
| 256 |
+
auto num_chars = const_value.GetStringTensorDataLength();
|
| 257 |
+
chars_.resize(num_chars + 1, '\0');
|
| 258 |
+
auto num_strings = static_cast<size_t>(NumberOfElement());
|
| 259 |
+
if (num_strings) {
|
| 260 |
+
std::vector<size_t> offsets(num_strings);
|
| 261 |
+
const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
|
| 262 |
+
offsets.push_back(num_chars);
|
| 263 |
+
for (size_t i = 0; i < num_strings; ++i) {
|
| 264 |
+
input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
const string_views& Data() const {
|
| 270 |
+
return input_string_views_;
|
| 271 |
+
}
|
| 272 |
+
const void* DataRaw() const override {
|
| 273 |
+
if (input_string_views_.size() != 1) {
|
| 274 |
+
ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
| 275 |
+
}
|
| 276 |
+
return reinterpret_cast<const void*>(input_string_views_[0].data());
|
| 277 |
+
}
|
| 278 |
+
size_t SizeInBytes() const override {
|
| 279 |
+
if (input_string_views_.size() != 1) {
|
| 280 |
+
ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
|
| 281 |
+
}
|
| 282 |
+
return input_string_views_[0].size();
|
| 283 |
+
}
|
| 284 |
+
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
|
| 285 |
+
shape_ = dims;
|
| 286 |
+
std::vector<const char*> raw;
|
| 287 |
+
for (const auto& s : ss) {
|
| 288 |
+
raw.push_back(s.data());
|
| 289 |
+
}
|
| 290 |
+
auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
|
| 291 |
+
// note - there will be copy ...
|
| 292 |
+
output.FillStringTensor(raw.data(), raw.size());
|
| 293 |
+
}
|
| 294 |
+
const Span<std::string_view>& AsSpan() {
|
| 295 |
+
ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
| 296 |
+
}
|
| 297 |
+
std::string_view AsScalar() {
|
| 298 |
+
if (input_string_views_.size() != 1) {
|
| 299 |
+
ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
|
| 300 |
+
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
| 301 |
+
}
|
| 302 |
+
return input_string_views_[0];
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
private:
|
| 306 |
+
std::vector<char> chars_; // for input
|
| 307 |
+
std::vector<std::string_view> input_string_views_; // for input
|
| 308 |
+
};
|
| 309 |
+
|
| 310 |
+
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
|
| 311 |
+
using TensorPtrs = std::vector<TensorPtr>;
|
| 312 |
+
|
| 313 |
+
struct TensorArray : public ArgBase {
|
| 314 |
+
TensorArray(OrtKernelContext* ctx,
|
| 315 |
+
size_t start_indice,
|
| 316 |
+
bool is_input) : ArgBase(ctx,
|
| 317 |
+
start_indice,
|
| 318 |
+
is_input) {
|
| 319 |
+
if (is_input) {
|
| 320 |
+
auto input_count = ctx_.GetInputCount();
|
| 321 |
+
for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
|
| 322 |
+
auto const_value = ctx_.GetInput(start_indice);
|
| 323 |
+
auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
|
| 324 |
+
auto type = type_shape_info.GetElementType();
|
| 325 |
+
TensorPtr tensor;
|
| 326 |
+
switch (type) {
|
| 327 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
|
| 328 |
+
tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
|
| 329 |
+
break;
|
| 330 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
| 331 |
+
tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
|
| 332 |
+
break;
|
| 333 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
|
| 334 |
+
tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
|
| 335 |
+
break;
|
| 336 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
| 337 |
+
tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
|
| 338 |
+
break;
|
| 339 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
|
| 340 |
+
tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
|
| 341 |
+
break;
|
| 342 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
|
| 343 |
+
tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
|
| 344 |
+
break;
|
| 345 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
|
| 346 |
+
tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
|
| 347 |
+
break;
|
| 348 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
|
| 349 |
+
tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
|
| 350 |
+
break;
|
| 351 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
| 352 |
+
tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
|
| 353 |
+
break;
|
| 354 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
|
| 355 |
+
tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
|
| 356 |
+
break;
|
| 357 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
| 358 |
+
tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
|
| 359 |
+
break;
|
| 360 |
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
|
| 361 |
+
tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
|
| 362 |
+
break;
|
| 363 |
+
default:
|
| 364 |
+
ORT_CXX_API_THROW("unknown input type", ORT_RUNTIME_EXCEPTION);
|
| 365 |
+
break;
|
| 366 |
+
}
|
| 367 |
+
tensors_.emplace_back(tensor.release());
|
| 368 |
+
} // for
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
template <typename T>
|
| 372 |
+
T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
|
| 373 |
+
// ith_output is the indice of output relative to the tensor array
|
| 374 |
+
// indice_ + ith_output is the indice relative to context
|
| 375 |
+
auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
|
| 376 |
+
auto raw_output = tensor.get()->Allocate(shape);
|
| 377 |
+
tensors_.emplace_back(tensor.release());
|
| 378 |
+
return raw_output;
|
| 379 |
+
}
|
| 380 |
+
Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
|
| 381 |
+
// ith_output is the indice of output relative to the tensor array
|
| 382 |
+
// indice_ + ith_output is the indice relative to context
|
| 383 |
+
auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
|
| 384 |
+
Tensor<std::string>& output = *tensor;
|
| 385 |
+
tensors_.emplace_back(tensor.release());
|
| 386 |
+
return output;
|
| 387 |
+
}
|
| 388 |
+
size_t Size() const {
|
| 389 |
+
return tensors_.size();
|
| 390 |
+
}
|
| 391 |
+
const TensorPtr& operator[](size_t ith_input) const {
|
| 392 |
+
// ith_input is the indice of output relative to the tensor array
|
| 393 |
+
return tensors_.at(ith_input);
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
private:
|
| 397 |
+
TensorPtrs tensors_;
|
| 398 |
+
};
|
| 399 |
+
|
| 400 |
+
using Variadic = TensorArray;
|
| 401 |
+
|
| 402 |
+
/*
|
| 403 |
+
Note:
|
| 404 |
+
OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core.
|
| 405 |
+
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
|
| 406 |
+
1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierarchy.
|
| 407 |
+
2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp,
|
| 408 |
+
hence memory could still be recycled properly.
|
| 409 |
+
Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety.
|
| 410 |
+
*/
|
| 411 |
+
struct OrtLiteCustomOp : public OrtCustomOp {
|
| 412 |
+
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
|
| 413 |
+
using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
|
| 414 |
+
|
| 415 |
+
// CreateTuple
|
| 416 |
+
template <size_t ith_input, size_t ith_output, typename... Ts>
|
| 417 |
+
static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
|
| 418 |
+
CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
|
| 419 |
+
return std::make_tuple();
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 423 |
+
static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
|
| 424 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 425 |
+
std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
|
| 426 |
+
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
|
| 427 |
+
return std::tuple_cat(current, next);
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 431 |
+
static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
|
| 432 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 433 |
+
std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
|
| 434 |
+
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
|
| 435 |
+
return std::tuple_cat(current, next);
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
#ifdef ORT_CUDA_CTX
|
| 439 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 440 |
+
static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
|
| 441 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 442 |
+
thread_local CudaContext cuda_context;
|
| 443 |
+
cuda_context.Init(*context);
|
| 444 |
+
std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
|
| 445 |
+
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
|
| 446 |
+
return std::tuple_cat(current, next);
|
| 447 |
+
}
|
| 448 |
+
#endif
|
| 449 |
+
|
| 450 |
+
#ifdef ORT_ROCM_CTX
|
| 451 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 452 |
+
static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
|
| 453 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 454 |
+
thread_local RocmContext rocm_context;
|
| 455 |
+
rocm_context.Init(*context);
|
| 456 |
+
std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
|
| 457 |
+
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
|
| 458 |
+
return std::tuple_cat(current, next);
|
| 459 |
+
}
|
| 460 |
+
#endif
|
| 461 |
+
|
| 462 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 463 |
+
static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
|
| 464 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 465 |
+
args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
|
| 466 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
|
| 467 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
|
| 468 |
+
return std::tuple_cat(current, next);
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 472 |
+
static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
|
| 473 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 474 |
+
args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
|
| 475 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
|
| 476 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
|
| 477 |
+
return std::tuple_cat(current, next);
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 481 |
+
static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
|
| 482 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 483 |
+
args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
|
| 484 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
|
| 485 |
+
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
|
| 486 |
+
return std::tuple_cat(current, next);
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
| 490 |
+
static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
|
| 491 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
|
| 492 |
+
args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
|
| 493 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
|
| 494 |
+
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
|
| 495 |
+
return std::tuple_cat(current, next);
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
#define CREATE_TUPLE_INPUT(data_type) \
|
| 499 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 500 |
+
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
|
| 501 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 502 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 503 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
|
| 504 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 505 |
+
return std::tuple_cat(current, next); \
|
| 506 |
+
} \
|
| 507 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 508 |
+
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
|
| 509 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 510 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 511 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
|
| 512 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 513 |
+
return std::tuple_cat(current, next); \
|
| 514 |
+
} \
|
| 515 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 516 |
+
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
|
| 517 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 518 |
+
if (ith_input < num_input) { \
|
| 519 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 520 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
|
| 521 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 522 |
+
return std::tuple_cat(current, next); \
|
| 523 |
+
} else { \
|
| 524 |
+
std::tuple<T> current = std::tuple<T>{}; \
|
| 525 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 526 |
+
return std::tuple_cat(current, next); \
|
| 527 |
+
} \
|
| 528 |
+
} \
|
| 529 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 530 |
+
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
|
| 531 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 532 |
+
if ("CPUExecutionProvider" != ep) { \
|
| 533 |
+
ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
| 534 |
+
} \
|
| 535 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 536 |
+
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
|
| 537 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 538 |
+
return std::tuple_cat(current, next); \
|
| 539 |
+
} \
|
| 540 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 541 |
+
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
|
| 542 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 543 |
+
if ("CPUExecutionProvider" != ep) { \
|
| 544 |
+
ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
| 545 |
+
} \
|
| 546 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 547 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
|
| 548 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 549 |
+
return std::tuple_cat(current, next); \
|
| 550 |
+
} \
|
| 551 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 552 |
+
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
|
| 553 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 554 |
+
if (ith_input < num_input) { \
|
| 555 |
+
if ("CPUExecutionProvider" != ep) { \
|
| 556 |
+
ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
| 557 |
+
} \
|
| 558 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 559 |
+
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
|
| 560 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 561 |
+
return std::tuple_cat(current, next); \
|
| 562 |
+
} else { \
|
| 563 |
+
std::tuple<T> current = std::tuple<T>{}; \
|
| 564 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 565 |
+
return std::tuple_cat(current, next); \
|
| 566 |
+
} \
|
| 567 |
+
} \
|
| 568 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 569 |
+
static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
|
| 570 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 571 |
+
if ("CPUExecutionProvider" != ep) { \
|
| 572 |
+
ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
| 573 |
+
} \
|
| 574 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 575 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
|
| 576 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 577 |
+
return std::tuple_cat(current, next); \
|
| 578 |
+
} \
|
| 579 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 580 |
+
static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
|
| 581 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 582 |
+
if (ith_input < num_input) { \
|
| 583 |
+
if ("CPUExecutionProvider" != ep) { \
|
| 584 |
+
ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
| 585 |
+
} \
|
| 586 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
| 587 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
|
| 588 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 589 |
+
return std::tuple_cat(current, next); \
|
| 590 |
+
} else { \
|
| 591 |
+
std::tuple<T> current = std::tuple<T>{}; \
|
| 592 |
+
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
|
| 593 |
+
return std::tuple_cat(current, next); \
|
| 594 |
+
} \
|
| 595 |
+
}
|
| 596 |
+
#define CREATE_TUPLE_OUTPUT(data_type) \
|
| 597 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 598 |
+
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
|
| 599 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 600 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
|
| 601 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
|
| 602 |
+
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
|
| 603 |
+
return std::tuple_cat(current, next); \
|
| 604 |
+
} \
|
| 605 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 606 |
+
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
|
| 607 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 608 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
|
| 609 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
|
| 610 |
+
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
|
| 611 |
+
return std::tuple_cat(current, next); \
|
| 612 |
+
} \
|
| 613 |
+
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
| 614 |
+
static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
|
| 615 |
+
CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
|
| 616 |
+
if (ith_output < num_output) { \
|
| 617 |
+
args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
|
| 618 |
+
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
|
| 619 |
+
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
|
| 620 |
+
return std::tuple_cat(current, next); \
|
| 621 |
+
} else { \
|
| 622 |
+
std::tuple<T> current = std::tuple<T>{}; \
|
| 623 |
+
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
|
| 624 |
+
return std::tuple_cat(current, next); \
|
| 625 |
+
} \
|
| 626 |
+
}
|
| 627 |
+
#define CREATE_TUPLE(data_type) \
|
| 628 |
+
CREATE_TUPLE_INPUT(data_type) \
|
| 629 |
+
CREATE_TUPLE_OUTPUT(data_type)
|
| 630 |
+
|
| 631 |
+
CREATE_TUPLE(bool)
|
| 632 |
+
CREATE_TUPLE(float)
|
| 633 |
+
CREATE_TUPLE(Ort::Float16_t)
|
| 634 |
+
CREATE_TUPLE(Ort::BFloat16_t)
|
| 635 |
+
CREATE_TUPLE(double)
|
| 636 |
+
CREATE_TUPLE(int8_t)
|
| 637 |
+
CREATE_TUPLE(int16_t)
|
| 638 |
+
CREATE_TUPLE(int32_t)
|
| 639 |
+
CREATE_TUPLE(int64_t)
|
| 640 |
+
CREATE_TUPLE(uint8_t)
|
| 641 |
+
CREATE_TUPLE(uint16_t)
|
| 642 |
+
CREATE_TUPLE(uint32_t)
|
| 643 |
+
CREATE_TUPLE(uint64_t)
|
| 644 |
+
CREATE_TUPLE(std::string)
|
| 645 |
+
CREATE_TUPLE_INPUT(std::string_view)
|
| 646 |
+
CREATE_TUPLE(Ort::Float8E4M3FN_t)
|
| 647 |
+
CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
|
| 648 |
+
CREATE_TUPLE(Ort::Float8E5M2_t)
|
| 649 |
+
CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
|
| 650 |
+
|
| 651 |
+
// ParseArgs ...
|
| 652 |
+
template <typename... Ts>
|
| 653 |
+
static typename std::enable_if<0 == sizeof...(Ts)>::type
|
| 654 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
template <typename T, typename... Ts>
|
| 658 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
|
| 659 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 660 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
template <typename T, typename... Ts>
|
| 664 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
|
| 665 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 666 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
#ifdef ORT_CUDA_CTX
|
| 670 |
+
template <typename T, typename... Ts>
|
| 671 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
|
| 672 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 673 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 674 |
+
}
|
| 675 |
+
#endif
|
| 676 |
+
|
| 677 |
+
#ifdef ORT_ROCM_CTX
|
| 678 |
+
template <typename T, typename... Ts>
|
| 679 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
|
| 680 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 681 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 682 |
+
}
|
| 683 |
+
#endif
|
| 684 |
+
|
| 685 |
+
template <typename T, typename... Ts>
|
| 686 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
|
| 687 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 688 |
+
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
| 689 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
template <typename T, typename... Ts>
|
| 693 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
|
| 694 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 695 |
+
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
| 696 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
template <typename T, typename... Ts>
|
| 700 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
|
| 701 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 702 |
+
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
| 703 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
template <typename T, typename... Ts>
|
| 707 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
|
| 708 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
| 709 |
+
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
|
| 710 |
+
ParseArgs<Ts...>(input_types, output_types);
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
|
| 714 |
+
template <typename T, typename... Ts> \
|
| 715 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
|
| 716 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
| 717 |
+
input_types.push_back(onnx_type); \
|
| 718 |
+
ParseArgs<Ts...>(input_types, output_types); \
|
| 719 |
+
} \
|
| 720 |
+
template <typename T, typename... Ts> \
|
| 721 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
|
| 722 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
| 723 |
+
input_types.push_back(onnx_type); \
|
| 724 |
+
ParseArgs<Ts...>(input_types, output_types); \
|
| 725 |
+
} \
|
| 726 |
+
template <typename T, typename... Ts> \
|
| 727 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
|
| 728 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
| 729 |
+
input_types.push_back(onnx_type); \
|
| 730 |
+
ParseArgs<Ts...>(input_types, output_types); \
|
| 731 |
+
}
|
| 732 |
+
|
| 733 |
+
#define PARSE_INPUT(data_type, onnx_type) \
|
| 734 |
+
PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
|
| 735 |
+
PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
|
| 736 |
+
PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
|
| 737 |
+
PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
|
| 738 |
+
PARSE_INPUT_BASE(data_type, onnx_type)
|
| 739 |
+
|
| 740 |
+
#define PARSE_OUTPUT(data_type, onnx_type) \
|
| 741 |
+
template <typename T, typename... Ts> \
|
| 742 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
|
| 743 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
| 744 |
+
output_types.push_back(onnx_type); \
|
| 745 |
+
ParseArgs<Ts...>(input_types, output_types); \
|
| 746 |
+
} \
|
| 747 |
+
template <typename T, typename... Ts> \
|
| 748 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
|
| 749 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
| 750 |
+
output_types.push_back(onnx_type); \
|
| 751 |
+
ParseArgs<Ts...>(input_types, output_types); \
|
| 752 |
+
} \
|
| 753 |
+
template <typename T, typename... Ts> \
|
| 754 |
+
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
|
| 755 |
+
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
| 756 |
+
output_types.push_back(onnx_type); \
|
| 757 |
+
ParseArgs<Ts...>(input_types, output_types); \
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
#define PARSE_ARGS(data_type, onnx_type) \
|
| 761 |
+
PARSE_INPUT(data_type, onnx_type) \
|
| 762 |
+
PARSE_OUTPUT(data_type, onnx_type)
|
| 763 |
+
|
| 764 |
+
PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
|
| 765 |
+
PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
|
| 766 |
+
PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
|
| 767 |
+
PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
|
| 768 |
+
PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
|
| 769 |
+
PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
|
| 770 |
+
PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
|
| 771 |
+
PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
|
| 772 |
+
PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
|
| 773 |
+
PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
|
| 774 |
+
PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
|
| 775 |
+
PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
|
| 776 |
+
PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
|
| 777 |
+
PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
|
| 778 |
+
PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
|
| 779 |
+
PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
|
| 780 |
+
PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
|
| 781 |
+
PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
|
| 782 |
+
PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
|
| 783 |
+
|
| 784 |
+
OrtLiteCustomOp(const char* op_name,
|
| 785 |
+
const char* execution_provider,
|
| 786 |
+
ShapeInferFn shape_infer_fn,
|
| 787 |
+
int start_ver = 1,
|
| 788 |
+
int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
|
| 789 |
+
execution_provider_(execution_provider),
|
| 790 |
+
shape_infer_fn_(shape_infer_fn),
|
| 791 |
+
start_ver_(start_ver),
|
| 792 |
+
end_ver_(end_ver) {
|
| 793 |
+
OrtCustomOp::version = ORT_API_VERSION;
|
| 794 |
+
|
| 795 |
+
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
|
| 796 |
+
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
|
| 797 |
+
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };
|
| 798 |
+
|
| 799 |
+
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
|
| 800 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 801 |
+
return self->input_types_.size();
|
| 802 |
+
};
|
| 803 |
+
|
| 804 |
+
OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
|
| 805 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 806 |
+
return self->input_types_[indice];
|
| 807 |
+
};
|
| 808 |
+
|
| 809 |
+
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
|
| 810 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 811 |
+
return self->output_types_.size();
|
| 812 |
+
};
|
| 813 |
+
|
| 814 |
+
OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
|
| 815 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 816 |
+
return self->output_types_[indice];
|
| 817 |
+
};
|
| 818 |
+
|
| 819 |
+
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
|
| 820 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 821 |
+
return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
|
| 822 |
+
};
|
| 823 |
+
|
| 824 |
+
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
|
| 825 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 826 |
+
return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
|
| 827 |
+
};
|
| 828 |
+
|
| 829 |
+
OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
|
| 830 |
+
return 1;
|
| 831 |
+
};
|
| 832 |
+
|
| 833 |
+
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
|
| 834 |
+
return 0;
|
| 835 |
+
};
|
| 836 |
+
|
| 837 |
+
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
|
| 838 |
+
return 1;
|
| 839 |
+
};
|
| 840 |
+
|
| 841 |
+
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
|
| 842 |
+
return 0;
|
| 843 |
+
};
|
| 844 |
+
|
| 845 |
+
OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
|
| 846 |
+
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
|
| 847 |
+
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
|
| 848 |
+
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
|
| 849 |
+
|
| 850 |
+
OrtCustomOp::CreateKernelV2 = {};
|
| 851 |
+
OrtCustomOp::KernelComputeV2 = {};
|
| 852 |
+
OrtCustomOp::KernelCompute = {};
|
| 853 |
+
|
| 854 |
+
OrtCustomOp::InferOutputShapeFn = {};
|
| 855 |
+
|
| 856 |
+
OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
|
| 857 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 858 |
+
return self->start_ver_;
|
| 859 |
+
};
|
| 860 |
+
|
| 861 |
+
OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
|
| 862 |
+
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
|
| 863 |
+
return self->end_ver_;
|
| 864 |
+
};
|
| 865 |
+
|
| 866 |
+
OrtCustomOp::GetMayInplace = {};
|
| 867 |
+
OrtCustomOp::ReleaseMayInplace = {};
|
| 868 |
+
OrtCustomOp::GetAliasMap = {};
|
| 869 |
+
OrtCustomOp::ReleaseAliasMap = {};
|
| 870 |
+
}
|
| 871 |
+
|
| 872 |
+
const std::string op_name_;
|
| 873 |
+
const std::string execution_provider_;
|
| 874 |
+
|
| 875 |
+
std::vector<ONNXTensorElementDataType> input_types_;
|
| 876 |
+
std::vector<ONNXTensorElementDataType> output_types_;
|
| 877 |
+
|
| 878 |
+
ShapeInferFn shape_infer_fn_ = {};
|
| 879 |
+
|
| 880 |
+
int start_ver_ = 1;
|
| 881 |
+
int end_ver_ = MAX_CUSTOM_OP_END_VER;
|
| 882 |
+
|
| 883 |
+
void* compute_fn_ = {};
|
| 884 |
+
void* compute_fn_return_status_ = {};
|
| 885 |
+
};
|
| 886 |
+
|
| 887 |
+
//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
|
| 888 |
+
// The struct is to implement function-as-op.
|
| 889 |
+
// E.g. a function might be defined as:
|
| 890 |
+
// void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
|
| 891 |
+
// It could be registered this way:
|
| 892 |
+
// Ort::CustomOpDomain v2_domain{"v2"};
|
| 893 |
+
// std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
|
| 894 |
+
// v2_domain.Add(fil_op_ptr.get());
|
| 895 |
+
// session_options.Add(v2_domain);
|
| 896 |
+
// For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
|
| 897 |
+
template <typename... Args>
|
| 898 |
+
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
|
| 899 |
+
using ComputeFn = void (*)(Args...);
|
| 900 |
+
using ComputeFnReturnStatus = Status (*)(Args...);
|
| 901 |
+
using MyType = OrtLiteCustomFunc<Args...>;
|
| 902 |
+
|
| 903 |
+
struct Kernel {
|
| 904 |
+
size_t num_input_{};
|
| 905 |
+
size_t num_output_{};
|
| 906 |
+
ComputeFn compute_fn_{};
|
| 907 |
+
ComputeFnReturnStatus compute_fn_return_status_{};
|
| 908 |
+
std::string ep_{};
|
| 909 |
+
};
|
| 910 |
+
|
| 911 |
+
OrtLiteCustomFunc(const char* op_name,
|
| 912 |
+
const char* execution_provider,
|
| 913 |
+
ComputeFn compute_fn,
|
| 914 |
+
ShapeInferFn shape_infer_fn = {},
|
| 915 |
+
int start_ver = 1,
|
| 916 |
+
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
|
| 917 |
+
compute_fn_ = reinterpret_cast<void*>(compute_fn);
|
| 918 |
+
ParseArgs<Args...>(input_types_, output_types_);
|
| 919 |
+
|
| 920 |
+
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
| 921 |
+
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
| 922 |
+
std::vector<ArgPtr> args;
|
| 923 |
+
auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
|
| 924 |
+
std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
|
| 925 |
+
};
|
| 926 |
+
|
| 927 |
+
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
|
| 928 |
+
auto kernel = std::make_unique<Kernel>();
|
| 929 |
+
auto me = static_cast<const MyType*>(this_);
|
| 930 |
+
kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
|
| 931 |
+
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
|
| 932 |
+
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
|
| 933 |
+
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
|
| 934 |
+
kernel->ep_ = self->execution_provider_;
|
| 935 |
+
return reinterpret_cast<void*>(kernel.release());
|
| 936 |
+
};
|
| 937 |
+
|
| 938 |
+
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
|
| 939 |
+
delete reinterpret_cast<Kernel*>(op_kernel);
|
| 940 |
+
};
|
| 941 |
+
|
| 942 |
+
if (shape_infer_fn_) {
|
| 943 |
+
OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
|
| 944 |
+
auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
|
| 945 |
+
ShapeInferContext ctx(&GetApi(), ort_ctx);
|
| 946 |
+
return shape_info_fn(ctx);
|
| 947 |
+
};
|
| 948 |
+
}
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
OrtLiteCustomFunc(const char* op_name,
|
| 952 |
+
const char* execution_provider,
|
| 953 |
+
ComputeFnReturnStatus compute_fn_return_status,
|
| 954 |
+
ShapeInferFn shape_infer_fn = {},
|
| 955 |
+
int start_ver = 1,
|
| 956 |
+
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
|
| 957 |
+
compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
|
| 958 |
+
ParseArgs<Args...>(input_types_, output_types_);
|
| 959 |
+
|
| 960 |
+
OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
|
| 961 |
+
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
| 962 |
+
std::vector<ArgPtr> args;
|
| 963 |
+
auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
|
| 964 |
+
return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
|
| 965 |
+
};
|
| 966 |
+
|
| 967 |
+
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
|
| 968 |
+
auto kernel = std::make_unique<Kernel>();
|
| 969 |
+
auto me = static_cast<const MyType*>(this_);
|
| 970 |
+
kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
|
| 971 |
+
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
|
| 972 |
+
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
|
| 973 |
+
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
|
| 974 |
+
kernel->ep_ = self->execution_provider_;
|
| 975 |
+
return reinterpret_cast<void*>(kernel.release());
|
| 976 |
+
};
|
| 977 |
+
|
| 978 |
+
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
|
| 979 |
+
delete reinterpret_cast<Kernel*>(op_kernel);
|
| 980 |
+
};
|
| 981 |
+
|
| 982 |
+
if (shape_infer_fn_) {
|
| 983 |
+
OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
|
| 984 |
+
auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
|
| 985 |
+
ShapeInferContext ctx(&GetApi(), ort_ctx);
|
| 986 |
+
return shape_info_fn(ctx);
|
| 987 |
+
};
|
| 988 |
+
}
|
| 989 |
+
}
|
| 990 |
+
}; // struct OrtLiteCustomFunc
|
| 991 |
+
|
| 992 |
+
/////////////////////////// OrtLiteCustomStruct ///////////////////////////
|
| 993 |
+
// The struct is to implement struct-as-op.
|
| 994 |
+
// E.g. a struct might be defined as:
|
| 995 |
+
// struct Merge {
|
| 996 |
+
// Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
|
| 997 |
+
// void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
|
| 998 |
+
// std::string_view string_in,
|
| 999 |
+
// Ort::Custom::Tensor<std::string>* strings_out) {...}
|
| 1000 |
+
// bool reverse_ = false;
|
| 1001 |
+
// };
|
| 1002 |
+
// It could be registered this way:
|
| 1003 |
+
// Ort::CustomOpDomain v2_domain{"v2"};
|
| 1004 |
+
// std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
|
| 1005 |
+
// v2_domain.Add(mrg_op_ptr.get());
|
| 1006 |
+
// session_options.Add(v2_domain);
|
| 1007 |
+
// For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
|
| 1008 |
+
template <typename CustomOp>
|
| 1009 |
+
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
|
| 1010 |
+
template <typename... Args>
|
| 1011 |
+
using CustomComputeFn = void (CustomOp::*)(Args...);
|
| 1012 |
+
|
| 1013 |
+
template <typename... Args>
|
| 1014 |
+
using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);
|
| 1015 |
+
|
| 1016 |
+
using MyType = OrtLiteCustomStruct<CustomOp>;
|
| 1017 |
+
|
| 1018 |
+
struct Kernel {
|
| 1019 |
+
size_t num_input_{};
|
| 1020 |
+
size_t num_output_{};
|
| 1021 |
+
std::unique_ptr<CustomOp> custom_op_;
|
| 1022 |
+
std::string ep_{};
|
| 1023 |
+
};
|
| 1024 |
+
|
| 1025 |
+
OrtLiteCustomStruct(const char* op_name,
|
| 1026 |
+
const char* execution_provider,
|
| 1027 |
+
int start_ver = 1,
|
| 1028 |
+
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
|
| 1029 |
+
SetCompute(&CustomOp::Compute);
|
| 1030 |
+
|
| 1031 |
+
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
|
| 1032 |
+
auto kernel = std::make_unique<Kernel>();
|
| 1033 |
+
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
|
| 1034 |
+
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
|
| 1035 |
+
kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
|
| 1036 |
+
auto self = static_cast<const OrtLiteCustomStruct*>(this_);
|
| 1037 |
+
kernel->ep_ = self->execution_provider_;
|
| 1038 |
+
return reinterpret_cast<void*>(kernel.release());
|
| 1039 |
+
};
|
| 1040 |
+
|
| 1041 |
+
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
|
| 1042 |
+
delete reinterpret_cast<Kernel*>(op_kernel);
|
| 1043 |
+
};
|
| 1044 |
+
|
| 1045 |
+
SetShapeInfer<CustomOp>(0);
|
| 1046 |
+
}
|
| 1047 |
+
|
| 1048 |
+
template <typename... Args>
|
| 1049 |
+
void SetCompute(CustomComputeFn<Args...>) {
|
| 1050 |
+
ParseArgs<Args...>(input_types_, output_types_);
|
| 1051 |
+
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
| 1052 |
+
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
| 1053 |
+
ArgPtrs args;
|
| 1054 |
+
auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
|
| 1055 |
+
std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
|
| 1056 |
+
};
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
template <typename... Args>
|
| 1060 |
+
void SetCompute(CustomComputeFnReturnStatus<Args...>) {
|
| 1061 |
+
ParseArgs<Args...>(input_types_, output_types_);
|
| 1062 |
+
OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
|
| 1063 |
+
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
| 1064 |
+
ArgPtrs args;
|
| 1065 |
+
auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
|
| 1066 |
+
return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
|
| 1067 |
+
};
|
| 1068 |
+
}
|
| 1069 |
+
|
| 1070 |
+
template <typename C>
|
| 1071 |
+
decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
|
| 1072 |
+
OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
|
| 1073 |
+
ShapeInferContext ctx(&GetApi(), ort_ctx);
|
| 1074 |
+
return C::InferOutputShape(ctx);
|
| 1075 |
+
};
|
| 1076 |
+
return {};
|
| 1077 |
+
}
|
| 1078 |
+
|
| 1079 |
+
template <typename C>
|
| 1080 |
+
void SetShapeInfer(...) {
|
| 1081 |
+
OrtCustomOp::InferOutputShapeFn = {};
|
| 1082 |
+
}
|
| 1083 |
+
}; // struct OrtLiteCustomStruct
|
| 1084 |
+
|
| 1085 |
+
/////////////////////////// CreateLiteCustomOp ////////////////////////////
|
| 1086 |
+
|
| 1087 |
+
template <typename... Args>
|
| 1088 |
+
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
|
| 1089 |
+
const char* execution_provider,
|
| 1090 |
+
void (*custom_compute_fn)(Args...),
|
| 1091 |
+
Status (*shape_infer_fn)(ShapeInferContext&) = {},
|
| 1092 |
+
int start_ver = 1,
|
| 1093 |
+
int end_ver = MAX_CUSTOM_OP_END_VER) {
|
| 1094 |
+
using LiteOp = OrtLiteCustomFunc<Args...>;
|
| 1095 |
+
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
|
| 1096 |
+
}
|
| 1097 |
+
|
| 1098 |
+
template <typename... Args>
|
| 1099 |
+
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
|
| 1100 |
+
const char* execution_provider,
|
| 1101 |
+
Status (*custom_compute_fn_v2)(Args...),
|
| 1102 |
+
Status (*shape_infer_fn)(ShapeInferContext&) = {},
|
| 1103 |
+
int start_ver = 1,
|
| 1104 |
+
int end_ver = MAX_CUSTOM_OP_END_VER) {
|
| 1105 |
+
using LiteOp = OrtLiteCustomFunc<Args...>;
|
| 1106 |
+
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
|
| 1107 |
+
}
|
| 1108 |
+
|
| 1109 |
+
template <typename CustomOp>
|
| 1110 |
+
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
|
| 1111 |
+
const char* execution_provider,
|
| 1112 |
+
int start_ver = 1,
|
| 1113 |
+
int end_ver = MAX_CUSTOM_OP_END_VER) {
|
| 1114 |
+
using LiteOp = OrtLiteCustomStruct<CustomOp>;
|
| 1115 |
+
return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
|
| 1116 |
+
}
|
| 1117 |
+
|
| 1118 |
+
} // namespace Custom
|
| 1119 |
+
} // namespace Ort
|
v1.23.1/headers/onnxruntime_run_options_config_keys.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
/*
|
| 7 |
+
* This file defines RunOptions Config Keys and format of the Config Values.
|
| 8 |
+
*
|
| 9 |
+
* The Naming Convention for a RunOptions Config Key,
|
| 10 |
+
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
|
| 11 |
+
* Such as "ep.cuda.use_arena"
|
| 12 |
+
* The Config Key cannot be empty
|
| 13 |
+
* The maximum length of the Config Key is 128
|
| 14 |
+
*
|
| 15 |
+
* The string format of a RunOptions Config Value is defined individually for each Config.
|
| 16 |
+
* The maximum length of the Config Value is 1024
|
| 17 |
+
*/
|
| 18 |
+
|
| 19 |
+
// Key for enabling shrinkages of user listed device memory arenas.
|
| 20 |
+
// Expects a list of semi-colon separated key value pairs separated by colon in the following format:
|
| 21 |
+
// "device_0:device_id_0;device_1:device_id_1"
|
| 22 |
+
// No white-spaces allowed in the provided list string.
|
| 23 |
+
// Currently, the only supported devices are : "cpu", "gpu" (case sensitive).
|
| 24 |
+
// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled.
|
| 25 |
+
// Example usage: "cpu:0;gpu:0" (or) "gpu:0"
|
| 26 |
+
// By default, the value for this key is empty (i.e.) no memory arenas are shrunk
|
| 27 |
+
static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
|
| 28 |
+
|
| 29 |
+
// Set to '1' to not synchronize execution providers with CPU at the end of session run.
|
| 30 |
+
// Per default it will be set to '0'
|
| 31 |
+
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
|
| 32 |
+
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
|
| 33 |
+
|
| 34 |
+
// Set HTP performance mode for QNN HTP backend before session run.
|
| 35 |
+
// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
|
| 36 |
+
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
|
| 37 |
+
// "sustained_high_performance". Default to "default".
|
| 38 |
+
static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
|
| 39 |
+
|
| 40 |
+
// Set HTP performance mode for QNN HTP backend post session run.
|
| 41 |
+
static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
|
| 42 |
+
|
| 43 |
+
// Set RPC control latency for QNN HTP backend
|
| 44 |
+
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
|
| 45 |
+
|
| 46 |
+
// Set QNN Lora Config File for apply Lora in QNN context binary
|
| 47 |
+
static const char* const kOrtRunOptionsConfigQnnLoraConfig = "qnn.lora_config";
|
| 48 |
+
|
| 49 |
+
// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
|
| 50 |
+
// The value should be an integer. If the value is not set, the default value is 0 and
|
| 51 |
+
// ORT session only captures one cuda graph before another capture is requested.
|
| 52 |
+
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
|
| 53 |
+
// User are not expected to set the value to 0 as it is reserved for internal use.
|
| 54 |
+
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
|
v1.23.1/headers/onnxruntime_session_options_config_keys.h
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
/*
|
| 7 |
+
* This file defines SessionOptions Config Keys and format of the Config Values.
|
| 8 |
+
*
|
| 9 |
+
* The Naming Convention for a SessionOptions Config Key,
|
| 10 |
+
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
|
| 11 |
+
* Such as "ep.cuda.use_arena"
|
| 12 |
+
* The Config Key cannot be empty
|
| 13 |
+
* The maximum length of the Config Key is 1024
|
| 14 |
+
*
|
| 15 |
+
* The string format of a SessionOptions Config Value is defined individually for each Config.
|
| 16 |
+
* The maximum length of the Config Value is 2048
|
| 17 |
+
*/
|
| 18 |
+
|
| 19 |
+
// Key for disable PrePacking,
|
| 20 |
+
// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
|
| 21 |
+
static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
|
| 22 |
+
|
| 23 |
+
// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
|
| 24 |
+
// will be used. Use this to override the usage of env allocators on a per session level.
|
| 25 |
+
static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
|
| 26 |
+
|
| 27 |
+
// Set to 'ORT' (case sensitive) to load an ORT format model.
|
| 28 |
+
// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
|
| 29 |
+
static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
|
| 30 |
+
|
| 31 |
+
// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
|
| 32 |
+
// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
|
| 33 |
+
static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";
|
| 34 |
+
|
| 35 |
+
// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
|
| 36 |
+
// When multiple sessions are created, a main thread doesn't override changes from succeeding session options,
|
| 37 |
+
// but threads in session thread pools follow option changes.
|
| 38 |
+
// When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and
|
| 39 |
+
// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool.
|
| 40 |
+
// Note that an alternative way not using this option at runtime is to train and export a model without denormals
|
| 41 |
+
// and that's recommended because turning this option on may hurt model accuracy.
|
| 42 |
+
static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";
|
| 43 |
+
|
| 44 |
+
// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not.
|
| 45 |
+
// "0": enable. ORT does fusion logic for QDQ format.
|
| 46 |
+
// "1": disable. ORT doesn't do fusion logic for QDQ format.
|
| 47 |
+
// Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1".
|
| 48 |
+
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
|
| 49 |
+
|
| 50 |
+
// It controls whether to enable Double QDQ remover and Identical Children Consolidation
|
| 51 |
+
// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
| 52 |
+
// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
| 53 |
+
// Its default value is "0"
|
| 54 |
+
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
|
| 55 |
+
|
| 56 |
+
// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
|
| 57 |
+
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
|
| 58 |
+
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
|
| 59 |
+
// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
|
| 60 |
+
// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
|
| 61 |
+
// As such, it's best to test to determine if enabling this works well for your scenario.
|
| 62 |
+
// The default value is "0"
|
| 63 |
+
// Available since version 1.11.
|
| 64 |
+
static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
|
| 65 |
+
|
| 66 |
+
// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
|
| 67 |
+
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
|
| 68 |
+
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
|
| 69 |
+
|
| 70 |
+
// Enable or disable Cast chain elimination in graph optimization. "0": disable; "1": enable. The default is "0".
|
| 71 |
+
// CastElimination with chain elimination has side effects which may change the inference results. It is disabled by default due to this.
|
| 72 |
+
static const char* const kOrtSessionOptionsEnableCastChainElimination = "optimization.enable_cast_chain_elimination";
|
| 73 |
+
|
| 74 |
+
// This setting controls whether to enable AheadOfTime function inlining.
|
| 75 |
+
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
|
| 76 |
+
// as possible with the help of enabled execution providers.
|
| 77 |
+
// This can reduce the number of function calls and improve performance because it is done before
|
| 78 |
+
// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
|
| 79 |
+
// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
|
| 80 |
+
// "0": enable; "1": disable.
|
| 81 |
+
// Its default value is "0".
|
| 82 |
+
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
|
| 83 |
+
|
| 84 |
+
#ifdef ENABLE_TRAINING
|
| 85 |
+
// Specifies a path of the file containing a list of memory optimization configurations.
|
| 86 |
+
// The value should be a string indicating the file path of the config file.
|
| 87 |
+
// The content of the config file is a JSON struct like this:
|
| 88 |
+
// [
|
| 89 |
+
// "Gelu+Cast+:1:0",
|
| 90 |
+
// "Dropout+:1:1"
|
| 91 |
+
// ]
|
| 92 |
+
// Taking the example of "Gelu+Cast+:1:0",
|
| 93 |
+
// > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation
|
| 94 |
+
// output by ORT graph transformations.
|
| 95 |
+
// > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute.
|
| 96 |
+
// > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization,
|
| 97 |
+
// to avoid "oversaving" the memory.
|
| 98 |
+
static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config";
|
| 99 |
+
|
| 100 |
+
// Specifies the config for detecting subgraphs for memory footprint reduction.
|
| 101 |
+
// The value should be a string contains int separated using commas. The default value is "0:0".
|
| 102 |
+
static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config";
|
| 103 |
+
#endif
|
| 104 |
+
|
| 105 |
+
// This setting if set should contain a comma separated list of optimizers names that should be disabled.
|
| 106 |
+
// Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer
|
| 107 |
+
// does not provider runtime benefits, but affects your model loading time you may disable it using this config
|
| 108 |
+
// entry. This option is not enabled in ORT_MINIMAL_BUILD build.
|
| 109 |
+
// A list of optimizes is available in onnxruntime/core/optimizer/graph_transformer_utils.cc
|
| 110 |
+
//
|
| 111 |
+
// Default is an empty string which means no optimizers are disabled.
|
| 112 |
+
static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
|
| 113 |
+
|
| 114 |
+
// It controls whether to run graph optimizations in loop or not.
|
| 115 |
+
//
|
| 116 |
+
// "0": disable. Graph Optimization Loop is disabled.
|
| 117 |
+
// ```
|
| 118 |
+
// Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
|
| 119 |
+
// ^ |
|
| 120 |
+
// | "No Loop" |
|
| 121 |
+
// | |
|
| 122 |
+
// X xxxxxxxxxxx X
|
| 123 |
+
// ```
|
| 124 |
+
// "1": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 4 are applied then
|
| 125 |
+
// the loop will check for any other valid optimization that can happen.
|
| 126 |
+
// ```
|
| 127 |
+
// Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
|
| 128 |
+
// ^ |
|
| 129 |
+
// | "Loop only depending on Level 4" |
|
| 130 |
+
// | |
|
| 131 |
+
// ---------------------------------------------------
|
| 132 |
+
// ```
|
| 133 |
+
// "2": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 2 or above are applied then
|
| 134 |
+
// The loop will check for any other valid optimization that can happen.
|
| 135 |
+
// ```
|
| 136 |
+
// Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
|
| 137 |
+
// ^ |
|
| 138 |
+
// | "Loop" |
|
| 139 |
+
// | |
|
| 140 |
+
// ---------------------------------------------------
|
| 141 |
+
// ```
|
| 142 |
+
// Default value is set to "1".
|
| 143 |
+
static const char* const kOrtSessionOptionsGraphOptimizationsLoopLevel = "session.graph_optimizations_loop_level";
|
| 144 |
+
|
| 145 |
+
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
|
| 146 |
+
// Using device allocators means the memory allocation is made using malloc/new.
|
| 147 |
+
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
|
| 148 |
+
|
| 149 |
+
// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking
|
| 150 |
+
// "0": thread will block if found no job to run
|
| 151 |
+
// "1": thread will spin a number of times before blocking
|
| 152 |
+
// The default is "0" when ORT is built with "ORT_CLIENT_PACKAGE_BUILD" and "1" otherwise.
|
| 153 |
+
// Thread spinning is disabled by default for client/on-device workloads to reduce cpu utilization and improve power efficiency.
|
| 154 |
+
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
|
| 155 |
+
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
|
| 156 |
+
|
| 157 |
+
// Key for using model bytes directly for ORT format
|
| 158 |
+
// If a session is created using an input byte array contains the ORT format model data,
|
| 159 |
+
// By default we will copy the model bytes at the time of session creation to ensure the model bytes
|
| 160 |
+
// buffer is valid.
|
| 161 |
+
// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller
|
| 162 |
+
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
|
| 163 |
+
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
|
| 164 |
+
|
| 165 |
+
/// <summary>
|
| 166 |
+
/// Key for using the ORT format model flatbuffer bytes directly for initializers.
|
| 167 |
+
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
|
| 168 |
+
/// Requires `session.use_ort_model_bytes_directly` to be true.
|
| 169 |
+
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
|
| 170 |
+
/// duration of the InferenceSession.
|
| 171 |
+
/// </summary>
|
| 172 |
+
static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
|
| 173 |
+
"session.use_ort_model_bytes_for_initializers";
|
| 174 |
+
|
| 175 |
+
// This should only be specified when exporting an ORT format model for use on a different platform.
|
| 176 |
+
// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
|
| 177 |
+
// Available since version 1.11.
|
| 178 |
+
static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
|
| 179 |
+
|
| 180 |
+
// x64 SSE4.1/AVX2/AVX512(with no VNNI) has overflow problem with quantizied matrix multiplication with U8S8.
|
| 181 |
+
// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if
|
| 182 |
+
// turned on, use slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512
|
| 183 |
+
// platforms.
|
| 184 |
+
static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
|
| 185 |
+
|
| 186 |
+
// Specifies how minimal build graph optimizations are handled in a full build.
|
| 187 |
+
// These optimizations are at the extended level or higher.
|
| 188 |
+
// Possible values and their effects are:
|
| 189 |
+
// "save": Save runtime optimizations when saving an ORT format model.
|
| 190 |
+
// "apply": Only apply optimizations available in a minimal build.
|
| 191 |
+
// ""/<unspecified>: Apply optimizations available in a full build.
|
| 192 |
+
// Available since version 1.11.
|
| 193 |
+
static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
|
| 194 |
+
"optimization.minimal_build_optimizations";
|
| 195 |
+
|
| 196 |
+
// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
|
| 197 |
+
// order for them to take effect.
|
| 198 |
+
|
| 199 |
+
// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
|
| 200 |
+
// run by the NNAPI EP.
|
| 201 |
+
// The value should be a ","-delimited list of op types. For example, "Add,Sub".
|
| 202 |
+
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
|
| 203 |
+
// exclusion, set the value to "".
|
| 204 |
+
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
|
| 205 |
+
|
| 206 |
+
// Enabling dynamic block-sizing for multithreading.
|
| 207 |
+
// With a positive value, thread pool will split a task of N iterations to blocks of size starting from:
|
| 208 |
+
// N / (num_of_threads * dynamic_block_base)
|
| 209 |
+
// As execution progresses, the size will decrease according to the diminishing residual of N,
|
| 210 |
+
// meaning the task will be distributed in smaller granularity for better parallelism.
|
| 211 |
+
// For some models, it helps to reduce the variance of E2E inference latency and boost performance.
|
| 212 |
+
// The feature will not function by default, specify any positive integer, e.g. "4", to enable it.
|
| 213 |
+
// Available since version 1.11.
|
| 214 |
+
static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
|
| 215 |
+
|
| 216 |
+
// This option allows to decrease CPU usage between infrequent
|
| 217 |
+
// requests and forces any TP threads spinning stop immediately when the last of
|
| 218 |
+
// concurrent Run() call returns.
|
| 219 |
+
// Spinning is restarted on the next Run() call.
|
| 220 |
+
// Applies only to internal thread-pools
|
| 221 |
+
static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
|
| 222 |
+
|
| 223 |
+
// "1": all inconsistencies encountered during shape and type inference
|
| 224 |
+
// will result in failures.
|
| 225 |
+
// "0": in some cases warnings will be logged but processing will continue. The default.
|
| 226 |
+
// May be useful to expose bugs in models.
|
| 227 |
+
static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
|
| 228 |
+
|
| 229 |
+
// "1": every model using a more recent opset than the latest released one will fail
|
| 230 |
+
// "0": the model may or may not work if onnxruntime cannot find an implementation, this option
|
| 231 |
+
// is used for development purpose.
|
| 232 |
+
static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only";
|
| 233 |
+
|
| 234 |
+
// The file saves configuration for partitioning node among logic streams
|
| 235 |
+
static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
|
| 236 |
+
|
| 237 |
+
// This Option allows setting affinities for intra op threads.
|
| 238 |
+
// Affinity string follows format:
|
| 239 |
+
// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
|
| 240 |
+
// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to.
|
| 241 |
+
// e.g.1,2,3;4,5
|
| 242 |
+
// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th.
|
| 243 |
+
// To ease the configuration, an "interval" is also allowed:
|
| 244 |
+
// e.g. 1-8;8-16;17-24
|
| 245 |
+
// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth.
|
| 246 |
+
// Note:
|
| 247 |
+
// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which
|
| 248 |
+
// is started and managed by the calling app;
|
| 249 |
+
// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors,
|
| 250 |
+
// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
|
| 251 |
+
// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary.
|
| 252 |
+
static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
|
| 253 |
+
|
| 254 |
+
// This option will dump out the model to assist debugging any issues with layout transformation,
|
| 255 |
+
// and is primarily intended for developer usage. It is only relevant if an execution provider that requests
|
| 256 |
+
// NHWC layout is enabled such as NNAPI, XNNPACK or QNN.
|
| 257 |
+
//
|
| 258 |
+
// Default is off. Set to "1" to enable.
|
| 259 |
+
//
|
| 260 |
+
// If modified by layout transformation the model will be dumped after these steps:
|
| 261 |
+
// 1) insertion of the layout transformation Transpose nodes
|
| 262 |
+
// 2) after those are optimized using the transpose optimizer,
|
| 263 |
+
// 3) after the L1 transformers are applied to the updated graph.
|
| 264 |
+
// The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
|
| 265 |
+
static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
|
| 266 |
+
|
| 267 |
+
// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
|
| 268 |
+
// assigned (i.e., "fallback") to the CPU EP by default.
|
| 269 |
+
//
|
| 270 |
+
// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
|
| 271 |
+
// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
|
| 272 |
+
// fully support all of the nodes in the graph.
|
| 273 |
+
//
|
| 274 |
+
// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
|
| 275 |
+
// will also fail with an error.
|
| 276 |
+
//
|
| 277 |
+
// Option values:
|
| 278 |
+
// - "0": CPU EP fallback is not disabled. [DEFAULT]
|
| 279 |
+
// - "1": CPU EP fallback is disabled.
|
| 280 |
+
static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";
|
| 281 |
+
|
| 282 |
+
// Use this config when serializing a large model after optimization to specify an external initializers file
|
| 283 |
+
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
|
| 284 |
+
"session.optimized_model_external_initializers_file_name";
|
| 285 |
+
|
| 286 |
+
// Use this config to control the minimum size of the initializer when externalizing it during serialization
|
| 287 |
+
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
|
| 288 |
+
"session.optimized_model_external_initializers_min_size_in_bytes";
|
| 289 |
+
|
| 290 |
+
// When loading model from memory buffer and the model has external initializers
|
| 291 |
+
// Use this config to set the external data file folder path
|
| 292 |
+
// All external data files should be in the same folder
|
| 293 |
+
static const char* const kOrtSessionOptionsModelExternalInitializersFileFolderPath =
|
| 294 |
+
"session.model_external_initializers_file_folder_path";
|
| 295 |
+
|
| 296 |
+
// Use this config when saving pre-packed constant initializers to an external data file.
|
| 297 |
+
// This allows you to memory map pre-packed initializers on model load and leave it to
|
| 298 |
+
// to the OS the amount of memory consumed by the pre-packed initializers. Otherwise,
|
| 299 |
+
// pre-packed data resides on the heap.
|
| 300 |
+
//
|
| 301 |
+
// - "0": Default is not save pre-packed initializers to a data file.
|
| 302 |
+
// - "1": Save pre-packed constant initializers to an external data file.
|
| 303 |
+
// Sample usage: sess_options.add_session_config_entry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1")
|
| 304 |
+
static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers =
|
| 305 |
+
"session.save_external_prepacked_constant_initializers";
|
| 306 |
+
|
| 307 |
+
// Use this config when you want to collect memory stats for each node in the graph.
|
| 308 |
+
// The file format is a CSV file with the following columns:
|
| 309 |
+
// The file will be created if it does not exist, and will be overwritten if it does.
|
| 310 |
+
//
|
| 311 |
+
// The content of the file can be used to estimate memory requirements at run time including
|
| 312 |
+
// the temporary allocations. This operation is preferably done on a CPU device, as the model may exceed
|
| 313 |
+
// device memory limits in constrained environments. When enabling this option, it is important to disable
|
| 314 |
+
// memory patterns, as they tend to allocate large blocks to avoid fragmentation and accommodate needs of multiple
|
| 315 |
+
// kernels. Memory patterns may make it difficult to allocate on a device with limited memory.
|
| 316 |
+
//
|
| 317 |
+
// The collected stats then can be used to partition the graph among the devices in a way that only the
|
| 318 |
+
// required memory is allocated on each device.
|
| 319 |
+
//
|
| 320 |
+
// node_name, initializers_memory, dynamic_outputs_sizes, temp_allocations_size
|
| 321 |
+
//
|
| 322 |
+
// - "full path to file": there is not a default for this option. If the file can not be opened for writing, an error will be returned.
|
| 323 |
+
static const char* const kOrtSessionOptionsCollectNodeMemoryStatsToFile = "session.collect_node_memory_stats_to_file";
|
| 324 |
+
|
| 325 |
+
/// This is a composite CSV setting formatted as "memory limit in kb,file name for collected stats"
|
| 326 |
+
/// "limit > 0": enables Capacity Aware Partitioning for Cuda EP. `limit` is optional and when absent
|
| 327 |
+
/// the provider may attempt to figure out the memory available automatically.
|
| 328 |
+
/// The setting with no limit is expected to look like: ",file name for collected stats"
|
| 329 |
+
/// The EP will place nodes on device "file name" :
|
| 330 |
+
/// this file is expected to be found at the same folder with the model. The file contains
|
| 331 |
+
/// pre-recorded stats collected when running with kOrtSessionOptionsCollectNodeMemoryStatsToFile enforce (see above)
|
| 332 |
+
static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings =
|
| 333 |
+
"session.resource_cuda_partitioning_settings";
|
| 334 |
+
|
| 335 |
+
// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
|
| 336 |
+
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
|
| 337 |
+
// "0": disable. (default)
|
| 338 |
+
// "1": enable.
|
| 339 |
+
static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
|
| 340 |
+
|
| 341 |
+
// Specify the file path for the Onnx model which has EP context.
|
| 342 |
+
// Default to original_file_name_ctx.onnx if not specified
|
| 343 |
+
// Folder is not a valid option
|
| 344 |
+
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
|
| 345 |
+
|
| 346 |
+
// Flag to specify whether to dump the EP context into the Onnx model.
|
| 347 |
+
// "0": dump the EP context into separate file, keep the file name in the Onnx model. (default).
|
| 348 |
+
// "1": dump the EP context into the Onnx model.
|
| 349 |
+
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
|
| 350 |
+
|
| 351 |
+
// Specify the EPContext node name prefix to make it unique
|
| 352 |
+
// in case user need to merge/connect multiple EPContext nodes in one model
|
| 353 |
+
static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix";
|
| 354 |
+
|
| 355 |
+
// Share EP related resources across sessions
|
| 356 |
+
static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts";
|
| 357 |
+
|
| 358 |
+
// Stop to share EP related resources across sessions from then on
|
| 359 |
+
static const char* const kOrtSessionOptionStopShareEpContexts = "ep.stop_share_ep_contexts";
|
| 360 |
+
|
| 361 |
+
// Used only for context model generation.
|
| 362 |
+
// This configuration is used when some nodes are partitioned on the CPU EP and those nodes have external initializers.
|
| 363 |
+
// When generating the EP context model, the new model should not rely on the old external data file used by the source ONNX model.
|
| 364 |
+
// Use this setting when dumping the EP context model with an external initializers file.
|
| 365 |
+
// If specified, all initializers will be placed inside the external data file.
|
| 366 |
+
// Otherwise, all initializers will be embedded inside the generated ONNX file.
|
| 367 |
+
// By default, this option is not set, meaning all initializers will be included within the ONNX file.
|
| 368 |
+
static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFileName =
|
| 369 |
+
"ep.context_model_external_initializers_file_name";
|
| 370 |
+
|
| 371 |
+
// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
|
| 372 |
+
// Option values:
|
| 373 |
+
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
|
| 374 |
+
// - "1": Gemm FastMath mode is enabled.
|
| 375 |
+
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
|
| 376 |
+
|
| 377 |
+
// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
|
| 378 |
+
// Refer to MatMulNBits op schema for more details.
|
| 379 |
+
// If not provided, default is 4.
|
| 380 |
+
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
|
| 381 |
+
|
| 382 |
+
// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
|
| 383 |
+
// Meant to be used with SetEpDynamicOptions
|
| 384 |
+
// Specify the type of workload for this session.
|
| 385 |
+
// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default]
|
| 386 |
+
// "Efficient": OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
|
| 387 |
+
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
|
| 388 |
+
|
| 389 |
+
// Disables model compilation during session initialization.
|
| 390 |
+
//
|
| 391 |
+
// If this option is set to "1", inference session creation will fail with error code ORT_MODEL_REQUIRES_COMPILATION
|
| 392 |
+
// if compilation is required to run the model on any Execution Provider added to the session.
|
| 393 |
+
// Only the following kinds of models are valid when this option is set to "1":
|
| 394 |
+
// - Pre-compiled models that have EPContext nodes for the compiling Execution Providers in the session.
|
| 395 |
+
// - Non-compiled models that run only on non-compiling Execution Providers, like CPU EP.
|
| 396 |
+
//
|
| 397 |
+
// See \href https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for details about
|
| 398 |
+
// compiled models with EPContext nodes.
|
| 399 |
+
//
|
| 400 |
+
// Option values:
|
| 401 |
+
// - "0": EP compile is not disabled. [DEFAULT]
|
| 402 |
+
// - "1": EP compile is disabled.
|
| 403 |
+
static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile";
|
| 404 |
+
|
| 405 |
+
// Controls behavior when compiled model compatibility is SUPPORTED_PREFER_RECOMPILATION.
|
| 406 |
+
// "0": Allow execution with suboptimal performance. [DEFAULT]
|
| 407 |
+
// "1": Fail session creation to require recompilation for optimal performance.
|
| 408 |
+
// Note: UNSUPPORTED models always fail regardless of this setting.
|
| 409 |
+
static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel =
|
| 410 |
+
"session.fail_on_suboptimal_compiled_model";
|
| 411 |
+
|
| 412 |
+
// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
|
| 413 |
+
// Meant to be used with SetEpDynamicOptions
|
| 414 |
+
// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
|
| 415 |
+
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
|
| 416 |
+
// "sustained_high_performance". Default to "default".
|
| 417 |
+
static const char* const kOrtEpDynamicOptionsQnnHtpPerformanceMode = "ep.dynamic.qnn_htp_performance_mode";
|
v1.23.1/jni/arm64-v8a/libonnxruntime.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:629d677b15742620b008fdeafafcf825e6635de5e650f2728b98cc7c659fbbe5
|
| 3 |
+
size 19343456
|
v1.23.1/jni/arm64-v8a/libonnxruntime4j_jni.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ebb4db00b87e243a6e7f401e30cac6b4172a921a59ba44bb72a53eba55e6920
|
| 3 |
+
size 100648
|
v1.23.1/jni/armeabi-v7a/libonnxruntime.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9c6a7a9d851346abaf9911e6fa1bbb4433ddeb0f5c9d17429cf59942707613bb
|
| 3 |
+
size 13988872
|
v1.23.1/jni/armeabi-v7a/libonnxruntime4j_jni.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c0ce9f1da045c17666540e61d0bd0861d9a6fe9044486ea78c1c34d6a8f9ecc9
|
| 3 |
+
size 73680
|
v1.23.1/jni/x86/libonnxruntime.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8a78e94433c3abb77a9c7ddf2b8b74e94f7af4e6073b3288a32d6d98a68c38ed
|
| 3 |
+
size 22757348
|
v1.23.1/jni/x86/libonnxruntime4j_jni.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:55bd2fc2473495913e3be73b706b8e0bd4138b561c80cacda52e487b2f47e30e
|
| 3 |
+
size 84700
|
v1.23.1/jni/x86_64/libonnxruntime.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b830967d557bd5ecf90de8357ffac3d6ec8a27f70d48d2e954e1809b0612165
|
| 3 |
+
size 23176928
|
v1.23.1/jni/x86_64/libonnxruntime4j_jni.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bfd2404d809a7339bd4ab7719af93c5786d003aff2ea1a1f312b6bc3b1d64366
|
| 3 |
+
size 90728
|