From 2a2a6daf1d22f134e415204c743c1a0193f727c0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 7 Jan 2026 22:15:53 +0000 Subject: [PATCH 01/11] Initial plan From b5bc13d93eb665ff02d7016db42fdf2a685bd137 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 7 Jan 2026 22:25:04 +0000 Subject: [PATCH 02/11] Convert jni_layer_llama.cpp from fbjni to pure JNI Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- .../executorch/extension/llm/LlmModule.java | 83 +- extension/android/jni/jni_helper.cpp | 55 ++ extension/android/jni/jni_helper.h | 29 +- extension/android/jni/jni_layer.cpp | 8 +- extension/android/jni/jni_layer_llama.cpp | 845 ++++++++++++------ 5 files changed, 699 insertions(+), 321 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index 5e080e0c369..54494979766 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -8,8 +8,6 @@ package org.pytorch.executorch.extension.llm; -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; import java.io.File; import java.util.List; import org.pytorch.executorch.ExecuTorchRuntime; @@ -28,18 +26,19 @@ public class LlmModule { public static final int MODEL_TYPE_TEXT_VISION = 2; public static final int MODEL_TYPE_MULTIMODAL = 2; - private final HybridData mHybridData; + private long mNativeHandle; private static final int DEFAULT_SEQ_LEN = 128; private static final boolean DEFAULT_ECHO = true; - @DoNotStrip - private static native HybridData initHybrid( + private static native long nativeCreate( int modelType, String modulePath, String tokenizerPath, float temperature, List dataFiles); + private static native void nativeDestroy(long nativeHandle); + /** * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and * dataFiles. @@ -61,7 +60,7 @@ public LlmModule( throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); } - mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles); + mNativeHandle = nativeCreate(modelType, modulePath, tokenizerPath, temperature, dataFiles); } /** @@ -107,7 +106,16 @@ public LlmModule(LlmModuleConfig config) { } public void resetNative() { - mHybridData.resetNative(); + if (mNativeHandle != 0) { + nativeDestroy(mNativeHandle); + mNativeHandle = 0; + } + } + + @Override + protected void finalize() throws Throwable { + resetNative(); + super.finalize(); } /** @@ -150,7 +158,12 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo); + public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { + return nativeGenerate(mNativeHandle, prompt, seqLen, llmCallback, echo); + } + + private static native int nativeGenerate( + long nativeHandle, String prompt, int seqLen, LlmCallback llmCallback, boolean echo); /** * Start generating tokens from the module. @@ -206,14 +219,15 @@ public int generate( */ @Experimental public long prefillImages(int[] image, int width, int height, int channels) { - int nativeResult = appendImagesInput(image, width, height, channels); + int nativeResult = nativeAppendImagesInput(mNativeHandle, image, width, height, channels); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private native int appendImagesInput(int[] image, int width, int height, int channels); + private static native int nativeAppendImagesInput( + long nativeHandle, int[] image, int width, int height, int channels); /** * Prefill a multimodal Module with the given images input. @@ -228,15 +242,16 @@ public long prefillImages(int[] image, int width, int height, int channels) { */ @Experimental public long prefillImages(float[] image, int width, int height, int channels) { - int nativeResult = appendNormalizedImagesInput(image, width, height, channels); + int nativeResult = + nativeAppendNormalizedImagesInput(mNativeHandle, image, width, height, channels); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private native int appendNormalizedImagesInput( - float[] image, int width, int height, int channels); + private static native int nativeAppendNormalizedImagesInput( + long nativeHandle, float[] image, int width, int height, int channels); /** * Prefill a multimodal Module with the given audio input. @@ -251,14 +266,15 @@ private native int appendNormalizedImagesInput( */ @Experimental public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames); + int nativeResult = nativeAppendAudioInput(mNativeHandle, audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); + private static native int nativeAppendAudioInput( + long nativeHandle, byte[] audio, int batch_size, int n_bins, int n_frames); /** * Prefill a multimodal Module with the given audio input. @@ -273,14 +289,16 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) */ @Experimental public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames); + int nativeResult = + nativeAppendAudioInputFloat(mNativeHandle, audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames); + private static native int nativeAppendAudioInputFloat( + long nativeHandle, float[] audio, int batch_size, int n_bins, int n_frames); /** * Prefill a multimodal Module with the given raw audio input. @@ -295,15 +313,16 @@ public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames */ @Experimental public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { - int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples); + int nativeResult = + nativeAppendRawAudioInput(mNativeHandle, audio, batch_size, n_channels, n_samples); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private native int appendRawAudioInput( - byte[] audio, int batch_size, int n_channels, int n_samples); + private static native int nativeAppendRawAudioInput( + long nativeHandle, byte[] audio, int batch_size, int n_channels, int n_samples); /** * Prefill a multimodal Module with the given text input. @@ -315,7 +334,7 @@ private native int appendRawAudioInput( */ @Experimental public long prefillPrompt(String prompt) { - int nativeResult = appendTextInput(prompt); + int nativeResult = nativeAppendTextInput(mNativeHandle, prompt); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } @@ -323,20 +342,30 @@ public long prefillPrompt(String prompt) { } // returns status - private native int appendTextInput(String prompt); + private static native int nativeAppendTextInput(long nativeHandle, String prompt); /** * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. * *

The startPos will be reset to 0. */ - public native void resetContext(); + public void resetContext() { + nativeResetContext(mNativeHandle); + } + + private static native void nativeResetContext(long nativeHandle); /** Stop current generate() before it finishes. */ - @DoNotStrip - public native void stop(); + public void stop() { + nativeStop(mNativeHandle); + } + + private static native void nativeStop(long nativeHandle); /** Force loading the module. Otherwise the model is loaded during first generate(). */ - @DoNotStrip - public native int load(); + public int load() { + return nativeLoad(mNativeHandle); + } + + private static native int nativeLoad(long nativeHandle); } diff --git a/extension/android/jni/jni_helper.cpp b/extension/android/jni/jni_helper.cpp index 6491524c7ac..37f9b271e52 100644 --- a/extension/android/jni/jni_helper.cpp +++ b/extension/android/jni/jni_helper.cpp @@ -10,6 +10,60 @@ namespace executorch::jni_helper { +void throwExecutorchException( + JNIEnv* env, + uint32_t errorCode, + const std::string& details) { + if (!env) { + return; + } + + // Find the exception class + jclass exceptionClass = + env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException"); + if (exceptionClass == nullptr) { + // Class not found, clear the exception and return + env->ExceptionClear(); + return; + } + + // Find the static factory method: makeExecutorchException(int, String) + jmethodID makeExceptionMethod = env->GetStaticMethodID( + exceptionClass, + "makeExecutorchException", + "(ILjava/lang/String;)Ljava/lang/RuntimeException;"); + if (makeExceptionMethod == nullptr) { + env->ExceptionClear(); + env->DeleteLocalRef(exceptionClass); + return; + } + + // Create the details string + jstring jDetails = env->NewStringUTF(details.c_str()); + if (jDetails == nullptr) { + env->ExceptionClear(); + env->DeleteLocalRef(exceptionClass); + return; + } + + // Call the factory method to create the exception object + jobject exception = env->CallStaticObjectMethod( + exceptionClass, + makeExceptionMethod, + static_cast(errorCode), + jDetails); + + env->DeleteLocalRef(jDetails); + + if (exception != nullptr) { + env->Throw(static_cast(exception)); + env->DeleteLocalRef(exception); + } + + env->DeleteLocalRef(exceptionClass); +} + +#if EXECUTORCH_HAS_FBJNI void throwExecutorchException(uint32_t errorCode, const std::string& details) { // Get the current JNI environment auto env = facebook::jni::Environment::current(); @@ -34,5 +88,6 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) { auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails); facebook::jni::throwNewJavaException(exception.get()); } +#endif } // namespace executorch::jni_helper diff --git a/extension/android/jni/jni_helper.h b/extension/android/jni/jni_helper.h index 898c1619d9c..683a3cfe447 100644 --- a/extension/android/jni/jni_helper.h +++ b/extension/android/jni/jni_helper.h @@ -8,9 +8,16 @@ #pragma once -#include +#include #include +#if __has_include() +#include +#define EXECUTORCH_HAS_FBJNI 1 +#else +#define EXECUTORCH_HAS_FBJNI 0 +#endif + namespace executorch::jni_helper { /** @@ -18,6 +25,25 @@ namespace executorch::jni_helper { * code and details. Uses the Java factory method * ExecutorchRuntimeException.makeExecutorchException(int, String). * + * This version takes JNIEnv* directly and works with pure JNI. + * + * @param env The JNI environment. + * @param errorCode The error code from the C++ Executorch runtime. + * @param details Additional details to include in the exception message. + */ +void throwExecutorchException( + JNIEnv* env, + uint32_t errorCode, + const std::string& details); + +#if EXECUTORCH_HAS_FBJNI +/** + * Throws a Java ExecutorchRuntimeException corresponding to the given error + * code and details. Uses the Java factory method + * ExecutorchRuntimeException.makeExecutorchException(int, String). + * + * This version uses fbjni to get the current JNI environment. + * * @param errorCode The error code from the C++ Executorch runtime. * @param details Additional details to include in the exception message. */ @@ -29,5 +55,6 @@ struct JExecutorchRuntimeException static constexpr auto kJavaDescriptor = "Lorg/pytorch/executorch/ExecutorchRuntimeException;"; }; +#endif } // namespace executorch::jni_helper diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 1f8457e00c5..0fbc0f14e54 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -535,10 +535,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass { } // namespace executorch::extension #ifdef EXECUTORCH_BUILD_LLAMA_JNI -extern void register_natives_for_llm(); +extern void register_natives_for_llm(JNIEnv* env); #else // No op if we don't build LLM -void register_natives_for_llm() {} +void register_natives_for_llm(JNIEnv* /* env */) {} #endif extern void register_natives_for_runtime(); @@ -552,7 +552,9 @@ void register_natives_for_training() {} JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { return facebook::jni::initialize(vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); - register_natives_for_llm(); + // Get JNIEnv for pure JNI registration in LLM + JNIEnv* env = facebook::jni::Environment::current(); + register_natives_for_llm(env); register_natives_for_runtime(); register_natives_for_training(); }); diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 888e09e7989..4affb119800 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -6,9 +6,12 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include #include +#include #include #include #include @@ -30,9 +33,6 @@ #include #endif -#include -#include - #if defined(EXECUTORCH_BUILD_QNN) #include #endif @@ -45,6 +45,10 @@ namespace llm = ::executorch::extension::llm; using ::executorch::runtime::Error; namespace { + +// Global JavaVM pointer for obtaining JNIEnv in callbacks +JavaVM* g_jvm = nullptr; + bool utf8_check_validity(const char* str, size_t length) { for (size_t i = 0; i < length; ++i) { uint8_t byte = static_cast(str[i]); @@ -79,47 +83,70 @@ bool utf8_check_validity(const char* str, size_t length) { } std::string token_buffer; -} // namespace -namespace executorch_jni { +// Helper to convert jstring to std::string +std::string jstring_to_string(JNIEnv* env, jstring jstr) { + if (jstr == nullptr) { + return ""; + } + const char* chars = env->GetStringUTFChars(jstr, nullptr); + if (chars == nullptr) { + return ""; + } + std::string result(chars); + env->ReleaseStringUTFChars(jstr, chars); + return result; +} -class ExecuTorchLlmCallbackJni - : public facebook::jni::JavaClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/extension/llm/LlmCallback;"; +// Helper to convert Java List to std::vector +std::vector jlist_to_string_vector(JNIEnv* env, jobject jlist) { + std::vector result; + if (jlist == nullptr) { + return result; + } - void onResult(std::string result) const { - static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic(); - static const auto method = - cls->getMethod)>("onResult"); + jclass list_class = env->FindClass("java/util/List"); + if (list_class == nullptr) { + env->ExceptionClear(); + return result; + } - token_buffer += result; - if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { - ET_LOG( - Info, "Current token buffer is not valid UTF-8. Waiting for more."); - return; - } - result = token_buffer; - token_buffer = ""; - facebook::jni::local_ref s = facebook::jni::make_jstring(result); - method(self(), s); + jmethodID size_method = env->GetMethodID(list_class, "size", "()I"); + jmethodID get_method = + env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;"); + + if (size_method == nullptr || get_method == nullptr) { + env->ExceptionClear(); + env->DeleteLocalRef(list_class); + return result; } - void onStats(const llm::Stats& result) const { - static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic(); - static const auto on_stats_method = - cls->getMethod)>("onStats"); - on_stats_method( - self(), - facebook::jni::make_jstring( - executorch::extension::llm::stats_to_json_string(result))); + jint size = env->CallIntMethod(jlist, size_method); + for (jint i = 0; i < size; ++i) { + jobject str_obj = env->CallObjectMethod(jlist, get_method, i); + if (str_obj != nullptr) { + result.push_back(jstring_to_string(env, static_cast(str_obj))); + env->DeleteLocalRef(str_obj); + } } -}; -class ExecuTorchLlmJni : public facebook::jni::HybridClass { - private: - friend HybridBase; + env->DeleteLocalRef(list_class); + return result; +} + +} // namespace + +namespace executorch_jni { + +// Model type category constants +constexpr int MODEL_TYPE_CATEGORY_LLM = 1; +constexpr int MODEL_TYPE_CATEGORY_MULTIMODAL = 2; +constexpr int MODEL_TYPE_MEDIATEK_LLAMA = 3; +constexpr int MODEL_TYPE_QNN_LLAMA = 4; + +// Native handle class that holds the runner state +class ExecuTorchLlmNative { + public: float temperature_ = 0.0f; int model_type_category_; std::unique_ptr runner_; @@ -127,37 +154,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { multi_modal_runner_; std::vector prefill_inputs_; - public: - constexpr static auto kJavaDescriptor = - "Lorg/pytorch/executorch/extension/llm/LlmModule;"; - - constexpr static int MODEL_TYPE_CATEGORY_LLM = 1; - constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2; - constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3; - constexpr static int MODEL_TYPE_QNN_LLAMA = 4; - - static facebook::jni::local_ref initHybrid( - facebook::jni::alias_ref, - jint model_type_category, - facebook::jni::alias_ref model_path, - facebook::jni::alias_ref tokenizer_path, - jfloat temperature, - facebook::jni::alias_ref::javaobject> - data_files) { - return makeCxxInstance( - model_type_category, - model_path, - tokenizer_path, - temperature, - data_files); - } - - ExecuTorchLlmJni( + ExecuTorchLlmNative( + JNIEnv* env, jint model_type_category, - facebook::jni::alias_ref model_path, - facebook::jni::alias_ref tokenizer_path, + jstring model_path, + jstring tokenizer_path, jfloat temperature, - facebook::jni::alias_ref data_files = nullptr) { + jobject data_files) { temperature_ = temperature; #if defined(ET_USE_THREADPOOL) // Reserve 1 thread for the main thread. @@ -171,44 +174,30 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { #endif model_type_category_ = model_type_category; - std::vector data_files_vector; + std::string model_path_str = jstring_to_string(env, model_path); + std::string tokenizer_path_str = jstring_to_string(env, tokenizer_path); + std::vector data_files_vector = + jlist_to_string_vector(env, data_files); + if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { multi_modal_runner_ = llm::create_multimodal_runner( - model_path->toStdString().c_str(), - llm::load_tokenizer(tokenizer_path->toStdString())); + model_path_str.c_str(), llm::load_tokenizer(tokenizer_path_str)); } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - if (data_files != nullptr) { - // Convert Java List to C++ std::vector - auto list_class = facebook::jni::findClassStatic("java/util/List"); - auto size_method = list_class->getMethod("size"); - auto get_method = - list_class->getMethod(jint)>( - "get"); - - jint size = size_method(data_files); - for (jint i = 0; i < size; ++i) { - auto str_obj = get_method(data_files, i); - auto jstr = facebook::jni::static_ref_cast(str_obj); - data_files_vector.push_back(jstr->toStdString()); - } - } runner_ = executorch::extension::llm::create_text_llm_runner( - model_path->toStdString(), - llm::load_tokenizer(tokenizer_path->toStdString()), - data_files_vector); + model_path_str, llm::load_tokenizer(tokenizer_path_str), data_files_vector); #if defined(EXECUTORCH_BUILD_QNN) } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { std::unique_ptr module = std::make_unique< executorch::extension::Module>( - model_path->toStdString().c_str(), + model_path_str.c_str(), data_files_vector, executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); std::string decoder_model = "llama3"; // use llama3 for now runner_ = std::make_unique>( // QNN runner std::move(module), decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), + model_path_str.c_str(), + tokenizer_path_str.c_str(), "", ""); model_type_category_ = MODEL_TYPE_CATEGORY_LLM; @@ -216,249 +205,525 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { #if defined(EXECUTORCH_BUILD_MEDIATEK) } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str()); + model_path_str.c_str(), tokenizer_path_str.c_str()); // Interpret the model type as LLM model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif } } +}; - jint generate( - facebook::jni::alias_ref prompt, - jint seq_len, - facebook::jni::alias_ref callback, - jboolean echo) { - if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { - std::vector inputs = prefill_inputs_; - prefill_inputs_.clear(); - if (!prompt->toStdString().empty()) { - inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()}); - } - executorch::extension::llm::GenerationConfig config{ - .echo = static_cast(echo), - .seq_len = seq_len, - .temperature = temperature_, - }; - multi_modal_runner_->generate( - std::move(inputs), - config, - [callback](const std::string& result) { callback->onResult(result); }, - [callback](const llm::Stats& result) { callback->onStats(result); }); - } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { - executorch::extension::llm::GenerationConfig config{ - .echo = static_cast(echo), - .seq_len = seq_len, - .temperature = temperature_, - }; - runner_->generate( - prompt->toStdString(), - config, - [callback](std::string result) { callback->onResult(result); }, - [callback](const llm::Stats& result) { callback->onStats(result); }); +// Helper class for callback invocation +class CallbackHelper { + public: + CallbackHelper(JNIEnv* env, jobject callback) + : env_(env), callback_(callback) { + if (callback_ != nullptr) { + callback_ = env_->NewGlobalRef(callback); + callback_class_ = env_->GetObjectClass(callback_); + on_result_method_ = env_->GetMethodID( + callback_class_, "onResult", "(Ljava/lang/String;)V"); + on_stats_method_ = + env_->GetMethodID(callback_class_, "onStats", "(Ljava/lang/String;)V"); } - return 0; } - // Returns status_code - // Contract is valid within an AAR (JNI + corresponding Java code) - jint append_text_input(facebook::jni::alias_ref prompt) { - prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()}); - return 0; + ~CallbackHelper() { + if (callback_ != nullptr) { + // Get the current JNIEnv (might be different thread) + JNIEnv* env = nullptr; + if (g_jvm != nullptr) { + int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6); + if (status == JNI_EDETACHED) { + g_jvm->AttachCurrentThread(&env, nullptr); + } + if (env != nullptr) { + env->DeleteGlobalRef(callback_); + if (callback_class_ != nullptr) { + env->DeleteGlobalRef(callback_class_); + } + } + } + } } - // Returns status_code - jint append_images_input( - facebook::jni::alias_ref image, - jint width, - jint height, - jint channels) { - std::vector images; - if (image == nullptr) { - return static_cast(Error::EndOfMethod); + void onResult(const std::string& result) { + JNIEnv* env = getEnv(); + if (env == nullptr || callback_ == nullptr || on_result_method_ == nullptr) { + return; } - auto image_size = image->size(); - if (image_size != 0) { - std::vector image_data_jint(image_size); - std::vector image_data(image_size); - image->getRegion(0, image_size, image_data_jint.data()); - for (int i = 0; i < image_size; i++) { - image_data[i] = image_data_jint[i]; - } - llm::Image image_runner{std::move(image_data), width, height, channels}; - prefill_inputs_.emplace_back( - llm::MultimodalInput{std::move(image_runner)}); + + std::string current_result = result; + token_buffer += current_result; + if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { + ET_LOG( + Info, "Current token buffer is not valid UTF-8. Waiting for more."); + return; } + current_result = token_buffer; + token_buffer = ""; - return 0; + jstring jstr = env->NewStringUTF(current_result.c_str()); + if (jstr != nullptr) { + env->CallVoidMethod(callback_, on_result_method_, jstr); + env->DeleteLocalRef(jstr); + } } - // Returns status_code - jint append_normalized_images_input( - facebook::jni::alias_ref image, - jint width, - jint height, - jint channels) { - std::vector images; - if (image == nullptr) { - return static_cast(Error::EndOfMethod); - } - auto image_size = image->size(); - if (image_size != 0) { - std::vector image_data_jfloat(image_size); - std::vector image_data(image_size); - image->getRegion(0, image_size, image_data_jfloat.data()); - for (int i = 0; i < image_size; i++) { - image_data[i] = image_data_jfloat[i]; - } - llm::Image image_runner{std::move(image_data), width, height, channels}; - prefill_inputs_.emplace_back( - llm::MultimodalInput{std::move(image_runner)}); + void onStats(const llm::Stats& stats) { + JNIEnv* env = getEnv(); + if (env == nullptr || callback_ == nullptr || on_stats_method_ == nullptr) { + return; } - return 0; + std::string stats_json = + executorch::extension::llm::stats_to_json_string(stats); + jstring jstr = env->NewStringUTF(stats_json.c_str()); + if (jstr != nullptr) { + env->CallVoidMethod(callback_, on_stats_method_, jstr); + env->DeleteLocalRef(jstr); + } } - // Returns status_code - jint append_audio_input( - facebook::jni::alias_ref data, - jint batch_size, - jint n_bins, - jint n_frames) { - if (data == nullptr) { - return static_cast(Error::EndOfMethod); + private: + JNIEnv* getEnv() { + if (g_jvm == nullptr) { + return nullptr; } - auto data_size = data->size(); - if (data_size != 0) { - std::vector data_jbyte(data_size); - std::vector data_u8(data_size); - data->getRegion(0, data_size, data_jbyte.data()); - for (int i = 0; i < data_size; i++) { - data_u8[i] = data_jbyte[i]; - } - llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames}; - prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); + JNIEnv* env = nullptr; + int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6); + if (status == JNI_EDETACHED) { + g_jvm->AttachCurrentThread(&env, nullptr); } - return 0; + return env; } - // Returns status_code - jint append_audio_input_float( - facebook::jni::alias_ref data, - jint batch_size, - jint n_bins, - jint n_frames) { - if (data == nullptr) { - return static_cast(Error::EndOfMethod); - } - auto data_size = data->size(); - if (data_size != 0) { - std::vector data_jfloat(data_size); - std::vector data_f(data_size); - data->getRegion(0, data_size, data_jfloat.data()); - for (int i = 0; i < data_size; i++) { - data_f[i] = data_jfloat[i]; - } - llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames}; - prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); + JNIEnv* env_; + jobject callback_; + jclass callback_class_ = nullptr; + jmethodID on_result_method_ = nullptr; + jmethodID on_stats_method_ = nullptr; +}; + +} // namespace executorch_jni + +extern "C" { + +JNIEXPORT jlong JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeCreate( + JNIEnv* env, + jobject /* this */, + jint model_type_category, + jstring model_path, + jstring tokenizer_path, + jfloat temperature, + jobject data_files) { + auto* native = new executorch_jni::ExecuTorchLlmNative( + env, model_type_category, model_path, tokenizer_path, temperature, data_files); + return reinterpret_cast(native); +} + +JNIEXPORT void JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy( + JNIEnv* /* env */, + jobject /* this */, + jlong native_handle) { + auto* native = + reinterpret_cast(native_handle); + delete native; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeGenerate( + JNIEnv* env, + jobject /* this */, + jlong native_handle, + jstring prompt, + jint seq_len, + jobject callback, + jboolean echo) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; + } + + std::string prompt_str = jstring_to_string(env, prompt); + + // Create a shared callback helper for use in lambdas + auto callback_helper = + std::make_shared(env, callback); + + if (native->model_type_category_ == + executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) { + std::vector inputs = native->prefill_inputs_; + native->prefill_inputs_.clear(); + if (!prompt_str.empty()) { + inputs.emplace_back(llm::MultimodalInput{prompt_str}); } - return 0; + executorch::extension::llm::GenerationConfig config{ + .echo = static_cast(echo), + .seq_len = seq_len, + .temperature = native->temperature_, + }; + native->multi_modal_runner_->generate( + std::move(inputs), + config, + [callback_helper](const std::string& result) { + callback_helper->onResult(result); + }, + [callback_helper](const llm::Stats& result) { + callback_helper->onStats(result); + }); + } else if ( + native->model_type_category_ == + executorch_jni::MODEL_TYPE_CATEGORY_LLM) { + executorch::extension::llm::GenerationConfig config{ + .echo = static_cast(echo), + .seq_len = seq_len, + .temperature = native->temperature_, + }; + native->runner_->generate( + prompt_str, + config, + [callback_helper](std::string result) { + callback_helper->onResult(result); + }, + [callback_helper](const llm::Stats& result) { + callback_helper->onStats(result); + }); + } + return 0; +} + +JNIEXPORT void JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeStop( + JNIEnv* /* env */, + jobject /* this */, + jlong native_handle) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return; } - // Returns status_code - jint append_raw_audio_input( - facebook::jni::alias_ref data, - jint batch_size, - jint n_channels, - jint n_samples) { - if (data == nullptr) { - return static_cast(Error::EndOfMethod); + if (native->model_type_category_ == + executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) { + native->multi_modal_runner_->stop(); + } else if ( + native->model_type_category_ == + executorch_jni::MODEL_TYPE_CATEGORY_LLM) { + native->runner_->stop(); + } +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeLoad( + JNIEnv* env, + jobject /* this */, + jlong native_handle) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; + } + + int result = -1; + std::stringstream ss; + + if (native->model_type_category_ == + executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) { + result = static_cast(native->multi_modal_runner_->load()); + if (result != 0) { + ss << "Failed to load multimodal runner: [" << result << "]"; } - auto data_size = data->size(); - if (data_size != 0) { - std::vector data_jbyte(data_size); - std::vector data_u8(data_size); - data->getRegion(0, data_size, data_jbyte.data()); - for (int i = 0; i < data_size; i++) { - data_u8[i] = data_jbyte[i]; - } - llm::RawAudio audio{ - std::move(data_u8), batch_size, n_channels, n_samples}; - prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); + } else if ( + native->model_type_category_ == + executorch_jni::MODEL_TYPE_CATEGORY_LLM) { + result = static_cast(native->runner_->load()); + if (result != 0) { + ss << "Failed to load llm runner: [" << result << "]"; } - return 0; + } else { + ss << "Invalid model type category: " << native->model_type_category_ + << ". Valid values are: " + << executorch_jni::MODEL_TYPE_CATEGORY_LLM << " or " + << executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL; } + if (result != 0) { + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + } + return result; // 0 on success to keep backward compatibility +} - void stop() { - if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { - multi_modal_runner_->stop(); - } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { - runner_->stop(); +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendTextInput( + JNIEnv* env, + jobject /* this */, + jlong native_handle, + jstring prompt) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; + } + + std::string prompt_str = jstring_to_string(env, prompt); + native->prefill_inputs_.emplace_back(llm::MultimodalInput{prompt_str}); + return 0; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendImagesInput( + JNIEnv* env, + jobject /* this */, + jlong native_handle, + jintArray image, + jint width, + jint height, + jint channels) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; + } + + if (image == nullptr) { + return static_cast(Error::EndOfMethod); + } + + jsize image_size = env->GetArrayLength(image); + if (image_size != 0) { + std::vector image_data_jint(image_size); + std::vector image_data(image_size); + env->GetIntArrayRegion(image, 0, image_size, image_data_jint.data()); + for (int i = 0; i < image_size; i++) { + image_data[i] = static_cast(image_data_jint[i]); } + llm::Image image_runner{std::move(image_data), width, height, channels}; + native->prefill_inputs_.emplace_back( + llm::MultimodalInput{std::move(image_runner)}); + } + + return 0; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendNormalizedImagesInput( + JNIEnv* env, + jobject /* this */, + jlong native_handle, + jfloatArray image, + jint width, + jint height, + jint channels) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; + } + + if (image == nullptr) { + return static_cast(Error::EndOfMethod); } - void reset_context() { - if (runner_ != nullptr) { - runner_->reset(); + jsize image_size = env->GetArrayLength(image); + if (image_size != 0) { + std::vector image_data_jfloat(image_size); + std::vector image_data(image_size); + env->GetFloatArrayRegion(image, 0, image_size, image_data_jfloat.data()); + for (int i = 0; i < image_size; i++) { + image_data[i] = image_data_jfloat[i]; } - if (multi_modal_runner_ != nullptr) { - multi_modal_runner_->reset(); + llm::Image image_runner{std::move(image_data), width, height, channels}; + native->prefill_inputs_.emplace_back( + llm::MultimodalInput{std::move(image_runner)}); + } + + return 0; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInput( + JNIEnv* env, + jobject /* this */, + jlong native_handle, + jbyteArray data, + jint batch_size, + jint n_bins, + jint n_frames) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; + } + + if (data == nullptr) { + return static_cast(Error::EndOfMethod); + } + + jsize data_size = env->GetArrayLength(data); + if (data_size != 0) { + std::vector data_jbyte(data_size); + std::vector data_u8(data_size); + env->GetByteArrayRegion(data, 0, data_size, data_jbyte.data()); + for (int i = 0; i < data_size; i++) { + data_u8[i] = static_cast(data_jbyte[i]); } + llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames}; + native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); + } + return 0; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInputFloat( + JNIEnv* env, + jobject /* this */, + jlong native_handle, + jfloatArray data, + jint batch_size, + jint n_bins, + jint n_frames) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; } - jint load() { - int result = -1; - std::stringstream ss; + if (data == nullptr) { + return static_cast(Error::EndOfMethod); + } - if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { - result = static_cast(multi_modal_runner_->load()); - if (result != 0) { - ss << "Failed to load multimodal runner: [" << result << "]"; - } - } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { - result = static_cast(runner_->load()); - if (result != 0) { - ss << "Failed to load llm runner: [" << result << "]"; - } - } else { - ss << "Invalid model type category: " << model_type_category_ - << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " - << MODEL_TYPE_CATEGORY_MULTIMODAL; + jsize data_size = env->GetArrayLength(data); + if (data_size != 0) { + std::vector data_jfloat(data_size); + std::vector data_f(data_size); + env->GetFloatArrayRegion(data, 0, data_size, data_jfloat.data()); + for (int i = 0; i < data_size; i++) { + data_f[i] = data_jfloat[i]; } - if (result != 0) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); + llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames}; + native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); + } + return 0; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendRawAudioInput( + JNIEnv* env, + jobject /* this */, + jlong native_handle, + jbyteArray data, + jint batch_size, + jint n_channels, + jint n_samples) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return -1; + } + + if (data == nullptr) { + return static_cast(Error::EndOfMethod); + } + + jsize data_size = env->GetArrayLength(data); + if (data_size != 0) { + std::vector data_jbyte(data_size); + std::vector data_u8(data_size); + env->GetByteArrayRegion(data, 0, data_size, data_jbyte.data()); + for (int i = 0; i < data_size; i++) { + data_u8[i] = static_cast(data_jbyte[i]); } - return result; // 0 on success to keep backward compatibility - } - - static void registerNatives() { - registerHybrid({ - makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid), - makeNativeMethod("generate", ExecuTorchLlmJni::generate), - makeNativeMethod("stop", ExecuTorchLlmJni::stop), - makeNativeMethod("load", ExecuTorchLlmJni::load), - makeNativeMethod( - "appendImagesInput", ExecuTorchLlmJni::append_images_input), - makeNativeMethod( - "appendNormalizedImagesInput", - ExecuTorchLlmJni::append_normalized_images_input), - makeNativeMethod( - "appendAudioInput", ExecuTorchLlmJni::append_audio_input), - makeNativeMethod( - "appendAudioInputFloat", - ExecuTorchLlmJni::append_audio_input_float), - makeNativeMethod( - "appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input), - makeNativeMethod( - "appendTextInput", ExecuTorchLlmJni::append_text_input), - makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), - }); + llm::RawAudio audio{std::move(data_u8), batch_size, n_channels, n_samples}; + native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); } -}; + return 0; +} -} // namespace executorch_jni +JNIEXPORT void JNICALL +Java_org_pytorch_executorch_extension_llm_LlmModule_nativeResetContext( + JNIEnv* /* env */, + jobject /* this */, + jlong native_handle) { + auto* native = + reinterpret_cast(native_handle); + if (native == nullptr) { + return; + } + + if (native->runner_ != nullptr) { + native->runner_->reset(); + } + if (native->multi_modal_runner_ != nullptr) { + native->multi_modal_runner_->reset(); + } +} + +} // extern "C" + +void register_natives_for_llm(JNIEnv* env) { + // Store the JavaVM for later use in callbacks + env->GetJavaVM(&g_jvm); + + jclass llm_module_class = + env->FindClass("org/pytorch/executorch/extension/llm/LlmModule"); + if (llm_module_class == nullptr) { + ET_LOG(Error, "Failed to find LlmModule class"); + env->ExceptionClear(); + return; + } + + // clang-format off + static const JNINativeMethod methods[] = { + {"nativeCreate", + "(ILjava/lang/String;Ljava/lang/String;FLjava/util/List;)J", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeCreate)}, + {"nativeDestroy", "(J)V", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy)}, + {"nativeGenerate", + "(JLjava/lang/String;ILorg/pytorch/executorch/extension/llm/LlmCallback;Z)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeGenerate)}, + {"nativeStop", "(J)V", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeStop)}, + {"nativeLoad", "(J)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeLoad)}, + {"nativeAppendTextInput", "(JLjava/lang/String;)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendTextInput)}, + {"nativeAppendImagesInput", "(J[IIII)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendImagesInput)}, + {"nativeAppendNormalizedImagesInput", "(J[FIII)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendNormalizedImagesInput)}, + {"nativeAppendAudioInput", "(J[BIII)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInput)}, + {"nativeAppendAudioInputFloat", "(J[FIII)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInputFloat)}, + {"nativeAppendRawAudioInput", "(J[BIII)I", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendRawAudioInput)}, + {"nativeResetContext", "(J)V", + reinterpret_cast( + Java_org_pytorch_executorch_extension_llm_LlmModule_nativeResetContext)}, + }; + // clang-format on + + int num_methods = sizeof(methods) / sizeof(methods[0]); + int result = env->RegisterNatives(llm_module_class, methods, num_methods); + if (result != JNI_OK) { + ET_LOG(Error, "Failed to register native methods for LlmModule"); + } -void register_natives_for_llm() { - executorch_jni::ExecuTorchLlmJni::registerNatives(); + env->DeleteLocalRef(llm_module_class); } From 26608109ce2b8fdc38367adccd1a34d69bb4bb48 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 7 Jan 2026 22:26:23 +0000 Subject: [PATCH 03/11] Fix callback_class_ memory management in CallbackHelper Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- extension/android/jni/jni_layer_llama.cpp | 37 ++++++++++++----------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 4affb119800..c6844552523 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -217,10 +217,12 @@ class ExecuTorchLlmNative { class CallbackHelper { public: CallbackHelper(JNIEnv* env, jobject callback) - : env_(env), callback_(callback) { - if (callback_ != nullptr) { + : env_(env), callback_(nullptr), callback_class_(nullptr) { + if (callback != nullptr) { callback_ = env_->NewGlobalRef(callback); - callback_class_ = env_->GetObjectClass(callback_); + jclass local_class = env_->GetObjectClass(callback); + callback_class_ = static_cast(env_->NewGlobalRef(local_class)); + env_->DeleteLocalRef(local_class); on_result_method_ = env_->GetMethodID( callback_class_, "onResult", "(Ljava/lang/String;)V"); on_stats_method_ = @@ -229,20 +231,21 @@ class CallbackHelper { } ~CallbackHelper() { - if (callback_ != nullptr) { - // Get the current JNIEnv (might be different thread) - JNIEnv* env = nullptr; - if (g_jvm != nullptr) { - int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6); - if (status == JNI_EDETACHED) { - g_jvm->AttachCurrentThread(&env, nullptr); - } - if (env != nullptr) { - env->DeleteGlobalRef(callback_); - if (callback_class_ != nullptr) { - env->DeleteGlobalRef(callback_class_); - } - } + if (g_jvm == nullptr) { + return; + } + // Get the current JNIEnv (might be different thread) + JNIEnv* env = nullptr; + int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6); + if (status == JNI_EDETACHED) { + g_jvm->AttachCurrentThread(&env, nullptr); + } + if (env != nullptr) { + if (callback_ != nullptr) { + env->DeleteGlobalRef(callback_); + } + if (callback_class_ != nullptr) { + env->DeleteGlobalRef(callback_class_); } } } From e4c4716c7e3d07c5dc5ccec7151dc1df0b820424 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 8 Jan 2026 00:59:54 +0000 Subject: [PATCH 04/11] Convert jni_layer.cpp and related files to pure JNI Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- .../java/org/pytorch/executorch/EValue.java | 40 +- .../pytorch/executorch/ExecuTorchRuntime.java | 15 +- .../java/org/pytorch/executorch/Module.java | 76 +- .../java/org/pytorch/executorch/Tensor.java | 14 +- extension/android/jni/jni_layer.cpp | 1052 ++++++++++------- extension/android/jni/jni_layer_runtime.cpp | 107 +- extension/android/jni/jni_layer_training.cpp | 2 +- 7 files changed, 775 insertions(+), 531 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java index ab3b77ff1fb..e0122e3979e 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java @@ -8,7 +8,6 @@ package org.pytorch.executorch; -import com.facebook.jni.annotations.DoNotStrip; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Locale; @@ -33,7 +32,6 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -@DoNotStrip public class EValue { private static final int TYPE_CODE_NONE = 0; @@ -47,52 +45,50 @@ public class EValue { "None", "Tensor", "String", "Double", "Int", "Bool", }; - @DoNotStrip private final int mTypeCode; - @DoNotStrip private Object mData; + final int mTypeCode; + Object mData; - @DoNotStrip private EValue(int typeCode) { this.mTypeCode = typeCode; } - @DoNotStrip public boolean isNone() { return TYPE_CODE_NONE == this.mTypeCode; } - @DoNotStrip + public boolean isTensor() { return TYPE_CODE_TENSOR == this.mTypeCode; } - @DoNotStrip + public boolean isBool() { return TYPE_CODE_BOOL == this.mTypeCode; } - @DoNotStrip + public boolean isInt() { return TYPE_CODE_INT == this.mTypeCode; } - @DoNotStrip + public boolean isDouble() { return TYPE_CODE_DOUBLE == this.mTypeCode; } - @DoNotStrip + public boolean isString() { return TYPE_CODE_STRING == this.mTypeCode; } /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */ - @DoNotStrip + public static EValue optionalNone() { return new EValue(TYPE_CODE_NONE); } /** Creates a new {@code EValue} of type {@code Tensor}. */ - @DoNotStrip + public static EValue from(Tensor tensor) { final EValue iv = new EValue(TYPE_CODE_TENSOR); iv.mData = tensor; @@ -100,7 +96,7 @@ public static EValue from(Tensor tensor) { } /** Creates a new {@code EValue} of type {@code bool}. */ - @DoNotStrip + public static EValue from(boolean value) { final EValue iv = new EValue(TYPE_CODE_BOOL); iv.mData = value; @@ -108,7 +104,7 @@ public static EValue from(boolean value) { } /** Creates a new {@code EValue} of type {@code int}. */ - @DoNotStrip + public static EValue from(long value) { final EValue iv = new EValue(TYPE_CODE_INT); iv.mData = value; @@ -116,7 +112,7 @@ public static EValue from(long value) { } /** Creates a new {@code EValue} of type {@code double}. */ - @DoNotStrip + public static EValue from(double value) { final EValue iv = new EValue(TYPE_CODE_DOUBLE); iv.mData = value; @@ -124,38 +120,38 @@ public static EValue from(double value) { } /** Creates a new {@code EValue} of type {@code str}. */ - @DoNotStrip + public static EValue from(String value) { final EValue iv = new EValue(TYPE_CODE_STRING); iv.mData = value; return iv; } - @DoNotStrip + public Tensor toTensor() { preconditionType(TYPE_CODE_TENSOR, mTypeCode); return (Tensor) mData; } - @DoNotStrip + public boolean toBool() { preconditionType(TYPE_CODE_BOOL, mTypeCode); return (boolean) mData; } - @DoNotStrip + public long toInt() { preconditionType(TYPE_CODE_INT, mTypeCode); return (long) mData; } - @DoNotStrip + public double toDouble() { preconditionType(TYPE_CODE_DOUBLE, mTypeCode); return (double) mData; } - @DoNotStrip + public String toStr() { preconditionType(TYPE_CODE_STRING, mTypeCode); return (String) mData; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 8e2f259ef3a..dfa9f77b6dd 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -8,7 +8,6 @@ package org.pytorch.executorch; -import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; @@ -33,10 +32,16 @@ public static ExecuTorchRuntime getRuntime() { } /** Get all registered ops. */ - @DoNotStrip - public static native String[] getRegisteredOps(); + public static String[] getRegisteredOps() { + return nativeGetRegisteredOps(); + } + + private static native String[] nativeGetRegisteredOps(); /** Get all registered backends. */ - @DoNotStrip - public static native String[] getRegisteredBackends(); + public static String[] getRegisteredBackends() { + return nativeGetRegisteredBackends(); + } + + private static native String[] nativeGetRegisteredBackends(); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 6da76bf4b74..481165f4e21 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -9,8 +9,6 @@ package org.pytorch.executorch; import android.util.Log; -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; import java.io.File; @@ -48,18 +46,18 @@ public class Module { /** Load mode for the module. Use memory locking and ignore errors. */ public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; - private final HybridData mHybridData; + private long mNativeHandle; private final Map mMethodMetadata; - @DoNotStrip - private static native HybridData initHybrid( - String moduleAbsolutePath, int loadMode, int initHybrid); + private static native long nativeCreate(String moduleAbsolutePath, int loadMode, int numThreads); + + private static native void nativeDestroy(long nativeHandle); private Module(String moduleAbsolutePath, int loadMode, int numThreads) { ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); - mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads); + mNativeHandle = nativeCreate(moduleAbsolutePath, loadMode, numThreads); mMethodMetadata = populateMethodMeta(); } @@ -75,7 +73,7 @@ Map populateMethodMeta() { return metadata; } - /** Lock protecting the non-thread safe methods in mHybridData. */ + /** Lock protecting the non-thread safe methods in native handle. */ private Lock mLock = new ReentrantLock(); /** @@ -138,18 +136,18 @@ public EValue[] forward(EValue... inputs) { public EValue[] execute(String methodName, EValue... inputs) { try { mLock.lock(); - if (!mHybridData.isValid()) { + if (mNativeHandle == 0) { Log.e("ExecuTorch", "Attempt to use a destroyed module"); return new EValue[0]; } - return executeNative(methodName, inputs); + return nativeExecute(mNativeHandle, methodName, inputs); } finally { mLock.unlock(); } } - @DoNotStrip - private native EValue[] executeNative(String methodName, EValue... inputs); + private static native EValue[] nativeExecute( + long nativeHandle, String methodName, EValue... inputs); /** * Load a method on this module. This might help with the first time inference performance, @@ -163,18 +161,17 @@ public EValue[] execute(String methodName, EValue... inputs) { public int loadMethod(String methodName) { try { mLock.lock(); - if (!mHybridData.isValid()) { + if (mNativeHandle == 0) { Log.e("ExecuTorch", "Attempt to use a destroyed module"); return 0x2; // InvalidState } - return loadMethodNative(methodName); + return nativeLoadMethod(mNativeHandle, methodName); } finally { mLock.unlock(); } } - @DoNotStrip - private native int loadMethodNative(String methodName); + private static native int nativeLoadMethod(long nativeHandle, String methodName); /** * Returns the names of the backends in a certain method. @@ -182,16 +179,22 @@ public int loadMethod(String methodName) { * @param methodName method name to query * @return an array of backend name */ - @DoNotStrip - private native String[] getUsedBackends(String methodName); + public String[] getUsedBackends(String methodName) { + return nativeGetUsedBackends(mNativeHandle, methodName); + } + + private static native String[] nativeGetUsedBackends(long nativeHandle, String methodName); /** * Returns the names of methods. * * @return name of methods in this Module */ - @DoNotStrip - public native String[] getMethods(); + public String[] getMethods() { + return nativeGetMethods(mNativeHandle); + } + + private static native String[] nativeGetMethods(long nativeHandle); /** * Get the corresponding @MethodMetadata for a method @@ -211,20 +214,18 @@ public MethodMetadata getMethodMetadata(String name) { return methodMetadata; } - @DoNotStrip - private static native String[] readLogBufferStaticNative(); + private static native String[] nativeReadLogBufferStatic(); public static String[] readLogBufferStatic() { - return readLogBufferStaticNative(); + return nativeReadLogBufferStatic(); } /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { - return readLogBufferNative(); + return nativeReadLogBuffer(mNativeHandle); } - @DoNotStrip - private native String[] readLogBufferNative(); + private static native String[] nativeReadLogBuffer(long nativeHandle); /** * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. @@ -234,19 +235,25 @@ public String[] readLogBuffer() { * @return true if the etdump was successfully written, false otherwise. */ @Experimental - @DoNotStrip - public native boolean etdump(); + public boolean etdump() { + return nativeEtdump(mNativeHandle); + } + + private static native boolean nativeEtdump(long nativeHandle); /** * Explicitly destroys the native Module object. Calling this method is not required, as the * native object will be destroyed when this object is garbage-collected. However, the timing of * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory - * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. + * more quickly. */ public void destroy() { if (mLock.tryLock()) { try { - mHybridData.resetNative(); + if (mNativeHandle != 0) { + nativeDestroy(mNativeHandle); + mNativeHandle = 0; + } } finally { mLock.unlock(); } @@ -257,4 +264,13 @@ public void destroy() { + " released."); } } + + @Override + protected void finalize() throws Throwable { + if (mNativeHandle != 0) { + nativeDestroy(mNativeHandle); + mNativeHandle = 0; + } + super.finalize(); + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index e8c0a918b13..a103e3691c2 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -9,8 +9,6 @@ package org.pytorch.executorch; import android.util.Log; -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -53,7 +51,7 @@ public abstract class Tensor { private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; - @DoNotStrip final long[] shape; + final long[] shape; private static final int BYTE_SIZE_BYTES = 1; private static final int INT_SIZE_BYTES = 4; @@ -468,7 +466,8 @@ public static Tensor zeros(long[] shape, DType dtype) { } } - @DoNotStrip private HybridData mHybridData; + // Native handle for tensor data (unused in pure JNI but kept for API compatibility) + private long mNativeHandle; private Tensor(long[] shape) { checkShape(shape); @@ -501,7 +500,6 @@ public long[] shape() { public abstract DType dtype(); // Called from native - @DoNotStrip int dtypeJniCode() { return dtype().jniCode; } @@ -572,7 +570,6 @@ public double[] getDataAsDoubleArray() { "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array."); } - @DoNotStrip Buffer getRawDataBuffer() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer."); @@ -889,9 +886,8 @@ private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[ // endregion checks // Called from native - @DoNotStrip private static Tensor nativeNewTensor( - ByteBuffer data, long[] shape, int dtype, HybridData hybridData) { + ByteBuffer data, long[] shape, int dtype, long nativeHandle) { Tensor tensor = null; if (DType.FLOAT.jniCode == dtype) { @@ -911,7 +907,7 @@ private static Tensor nativeNewTensor( } else { tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype)); } - tensor.mHybridData = hybridData; + tensor.mNativeHandle = nativeHandle; return tensor; } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 0fbc0f14e54..8645e0dd397 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include @@ -39,223 +41,116 @@ #include #endif -#include -#include - using namespace executorch::extension; using namespace torch::executor; -namespace executorch::extension { -class TensorHybrid : public facebook::jni::HybridClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/Tensor;"; +namespace { - explicit TensorHybrid(executorch::aten::Tensor tensor) {} +// Global JavaVM pointer for obtaining JNIEnv in callbacks +JavaVM* g_jvm = nullptr; - static facebook::jni::local_ref - newJTensorFromTensor(const executorch::aten::Tensor& tensor) { - // Java wrapper currently only supports contiguous tensors. +// Helper to convert jstring to std::string +std::string jstring_to_string(JNIEnv* env, jstring jstr) { + if (jstr == nullptr) { + return ""; + } + const char* chars = env->GetStringUTFChars(jstr, nullptr); + if (chars == nullptr) { + return ""; + } + std::string result(chars); + env->ReleaseStringUTFChars(jstr, chars); + return result; +} - const auto scalarType = tensor.scalar_type(); - int jdtype = scalar_type_to_java_dtype.at(scalarType); - if (scalar_type_to_java_dtype.count(scalarType) == 0) { - std::stringstream ss; - ss << "executorch::aten::Tensor scalar [java] type: " << jdtype - << " is not supported on java side"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); +// EValue type codes (must match Java EValue class) +constexpr int kTypeCodeNone = 0; +constexpr int kTypeCodeTensor = 1; +constexpr int kTypeCodeString = 2; +constexpr int kTypeCodeDouble = 3; +constexpr int kTypeCodeInt = 4; +constexpr int kTypeCodeBool = 5; + +// Cached class and method IDs for performance +struct JniCache { + jclass tensor_class = nullptr; + jclass evalue_class = nullptr; + jmethodID tensor_nativeNewTensor = nullptr; + jmethodID tensor_dtypeJniCode = nullptr; + jmethodID tensor_getRawDataBuffer = nullptr; + jfieldID tensor_shape = nullptr; + jmethodID evalue_from_tensor = nullptr; + jmethodID evalue_from_long = nullptr; + jmethodID evalue_from_double = nullptr; + jmethodID evalue_from_bool = nullptr; + jmethodID evalue_from_string = nullptr; + jmethodID evalue_toTensor = nullptr; + jfieldID evalue_mTypeCode = nullptr; + jfieldID evalue_mData = nullptr; + + bool initialized = false; + + void init(JNIEnv* env) { + if (initialized) { + return; } - const auto& tensor_shape = tensor.sizes(); - std::vector tensor_shape_vec; - for (const auto& s : tensor_shape) { - tensor_shape_vec.push_back(s); - } - facebook::jni::local_ref jTensorShape = - facebook::jni::make_long_array(tensor_shape_vec.size()); - jTensorShape->setRegion( - 0, tensor_shape_vec.size(), tensor_shape_vec.data()); - - static auto cls = TensorHybrid::javaClassStatic(); - // Note: this is safe as long as the data stored in tensor is valid; the - // data won't go out of scope as long as the Method for the inference is - // valid and there is no other inference call. Java layer picks up this - // value immediately so the data is valid. - facebook::jni::local_ref jTensorBuffer = - facebook::jni::JByteBuffer::wrapBytes( - (uint8_t*)tensor.data_ptr(), tensor.nbytes()); - jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder()); - - static const auto jMethodNewTensor = - cls->getStaticMethod( - facebook::jni::alias_ref, - facebook::jni::alias_ref, - jint, - facebook::jni::alias_ref)>("nativeNewTensor"); - return jMethodNewTensor( - cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor)); - } - - static TensorPtr newTensorFromJTensor( - facebook::jni::alias_ref jtensor) { - static auto cls = TensorHybrid::javaClassStatic(); - static const auto dtypeMethod = cls->getMethod("dtypeJniCode"); - jint jdtype = dtypeMethod(jtensor); - - static const auto shapeField = cls->getField("shape"); - auto jshape = jtensor->getFieldValue(shapeField); - - static auto dataBufferMethod = cls->getMethod< - facebook::jni::local_ref()>( - "getRawDataBuffer"); - facebook::jni::local_ref jbuffer = - dataBufferMethod(jtensor); - - const auto rank = jshape->size(); - - const auto shapeArr = jshape->getRegion(0, rank); - std::vector shape_vec; - shape_vec.reserve(rank); - - int64_t numel = 1; - for (int i = 0; i < rank; i++) { - shape_vec.push_back(shapeArr[i]); + // Cache Tensor class and methods + jclass local_tensor_class = env->FindClass("org/pytorch/executorch/Tensor"); + if (local_tensor_class != nullptr) { + tensor_class = static_cast(env->NewGlobalRef(local_tensor_class)); + env->DeleteLocalRef(local_tensor_class); + + tensor_nativeNewTensor = env->GetStaticMethodID( + tensor_class, + "nativeNewTensor", + "(Ljava/nio/ByteBuffer;[JIJ)Lorg/pytorch/executorch/Tensor;"); + tensor_dtypeJniCode = env->GetMethodID(tensor_class, "dtypeJniCode", "()I"); + tensor_getRawDataBuffer = + env->GetMethodID(tensor_class, "getRawDataBuffer", "()Ljava/nio/Buffer;"); + tensor_shape = env->GetFieldID(tensor_class, "shape", "[J"); } - for (int i = rank - 1; i >= 0; --i) { - numel *= shapeArr[i]; - } - JNIEnv* jni = facebook::jni::Environment::current(); - if (java_dtype_to_scalar_type.count(jdtype) == 0) { - std::stringstream ss; - ss << "Unknown Tensor jdtype: [" << jdtype << "]"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - } - ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); - const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); - if (dataCapacity < 0) { - std::stringstream ss; - ss << "Tensor buffer is not direct or has invalid capacity"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - } - const size_t elementSize = executorch::runtime::elementSize(scalar_type); - const jlong expectedElements = static_cast(numel); - const jlong expectedBytes = - expectedElements * static_cast(elementSize); - const bool matchesElements = dataCapacity == expectedElements; - const bool matchesBytes = dataCapacity == expectedBytes; - if (!matchesElements && !matchesBytes) { - std::stringstream ss; - ss << "Tensor dimensions(elements number: " << numel - << ") inconsistent with buffer capacity " << dataCapacity - << " (element size bytes: " << elementSize << ")"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - } - return from_blob( - jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type); - } - private: - friend HybridBase; -}; - -class JEValue : public facebook::jni::JavaClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/EValue;"; - - constexpr static int kTypeCodeTensor = 1; - constexpr static int kTypeCodeString = 2; - constexpr static int kTypeCodeDouble = 3; - constexpr static int kTypeCodeInt = 4; - constexpr static int kTypeCodeBool = 5; - - static facebook::jni::local_ref newJEValueFromEValue(EValue evalue) { - if (evalue.isTensor()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod( - facebook::jni::local_ref)>("from"); - return jMethodTensor( - JEValue::javaClassStatic(), - TensorHybrid::newJTensorFromTensor(evalue.toTensor())); - } else if (evalue.isInt()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod(jlong)>( - "from"); - return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt()); - } else if (evalue.isDouble()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod(jdouble)>( - "from"); - return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble()); - } else if (evalue.isBool()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod(jboolean)>( - "from"); - return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool()); - } else if (evalue.isString()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod( - facebook::jni::local_ref)>("from"); - std::string str = - std::string(evalue.toString().begin(), evalue.toString().end()); - return jMethodTensor( - JEValue::javaClassStatic(), facebook::jni::make_jstring(str)); - } - std::stringstream ss; - ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - return {}; - } - - static TensorPtr JEValueToTensorImpl( - facebook::jni::alias_ref JEValue) { - static const auto typeCodeField = - JEValue::javaClassStatic()->getField("mTypeCode"); - const auto typeCode = JEValue->getFieldValue(typeCodeField); - if (JEValue::kTypeCodeTensor == typeCode) { - static const auto jMethodGetTensor = - JEValue::javaClassStatic() - ->getMethod()>( - "toTensor"); - auto jtensor = jMethodGetTensor(JEValue); - return TensorHybrid::newTensorFromJTensor(jtensor); + // Cache EValue class and methods + jclass local_evalue_class = env->FindClass("org/pytorch/executorch/EValue"); + if (local_evalue_class != nullptr) { + evalue_class = static_cast(env->NewGlobalRef(local_evalue_class)); + env->DeleteLocalRef(local_evalue_class); + + evalue_from_tensor = env->GetStaticMethodID( + evalue_class, + "from", + "(Lorg/pytorch/executorch/Tensor;)Lorg/pytorch/executorch/EValue;"); + evalue_from_long = + env->GetStaticMethodID(evalue_class, "from", "(J)Lorg/pytorch/executorch/EValue;"); + evalue_from_double = + env->GetStaticMethodID(evalue_class, "from", "(D)Lorg/pytorch/executorch/EValue;"); + evalue_from_bool = + env->GetStaticMethodID(evalue_class, "from", "(Z)Lorg/pytorch/executorch/EValue;"); + evalue_from_string = env->GetStaticMethodID( + evalue_class, + "from", + "(Ljava/lang/String;)Lorg/pytorch/executorch/EValue;"); + evalue_toTensor = env->GetMethodID( + evalue_class, "toTensor", "()Lorg/pytorch/executorch/Tensor;"); + evalue_mTypeCode = env->GetFieldID(evalue_class, "mTypeCode", "I"); + evalue_mData = env->GetFieldID(evalue_class, "mData", "Ljava/lang/Object;"); } - std::stringstream ss; - ss << "Unknown EValue typeCode: " << typeCode; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - return {}; + + initialized = true; } }; -class ExecuTorchJni : public facebook::jni::HybridClass { - private: - friend HybridBase; - std::unique_ptr module_; +JniCache g_jni_cache; +// Native module handle class +class ExecuTorchModuleNative { public: - constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;"; - - static facebook::jni::local_ref initHybrid( - facebook::jni::alias_ref, - facebook::jni::alias_ref modelPath, - jint loadMode, - jint numThreads) { - return makeCxxInstance(modelPath, loadMode, numThreads); - } + std::unique_ptr module_; - ExecuTorchJni( - facebook::jni::alias_ref modelPath, + ExecuTorchModuleNative( + JNIEnv* env, + jstring modelPath, jint loadMode, jint numThreads) { Module::LoadMode load_mode = Module::LoadMode::Mmap; @@ -273,17 +168,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #else auto etdump_gen = nullptr; #endif - module_ = std::make_unique( - modelPath->toStdString(), load_mode, std::move(etdump_gen)); + std::string path = jstring_to_string(env, modelPath); + module_ = std::make_unique(path, load_mode, std::move(etdump_gen)); #ifdef ET_USE_THREADPOOL - // Default to using cores/2 threadpool threads. The long-term plan is to - // improve performant core detection in CPUInfo, but for now we can use - // cores/2 as a sane default. - // - // Based on testing, this is almost universally faster than using all - // cores, as efficiency cores can be quite slow. In extreme cases, using - // all cores can be 10x slower than using cores/2. auto threadpool = executorch::extension::threadpool::get_threadpool(); if (threadpool) { int thread_count = @@ -294,245 +182,515 @@ class ExecuTorchJni : public facebook::jni::HybridClass { } #endif } +}; + +// Helper to create Java Tensor from native tensor +jobject newJTensorFromTensor(JNIEnv* env, const executorch::aten::Tensor& tensor) { + g_jni_cache.init(env); + + const auto scalarType = tensor.scalar_type(); + if (scalar_type_to_java_dtype.count(scalarType) == 0) { + std::stringstream ss; + ss << "executorch::aten::Tensor scalar type is not supported on java side"; + jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; + } + int jdtype = scalar_type_to_java_dtype.at(scalarType); - facebook::jni::local_ref> execute( - facebook::jni::alias_ref methodName, - facebook::jni::alias_ref< - facebook::jni::JArrayClass::javaobject> - jinputs) { - return execute_method(methodName->toStdString(), jinputs); + // Create shape array + const auto& tensor_shape = tensor.sizes(); + jlongArray jTensorShape = env->NewLongArray(tensor_shape.size()); + if (jTensorShape == nullptr) { + return nullptr; } + std::vector shape_vec; + for (const auto& s : tensor_shape) { + shape_vec.push_back(s); + } + env->SetLongArrayRegion(jTensorShape, 0, shape_vec.size(), shape_vec.data()); + + // Create ByteBuffer wrapping tensor data + jobject jTensorBuffer = env->NewDirectByteBuffer( + const_cast(tensor.const_data_ptr()), tensor.nbytes()); + if (jTensorBuffer == nullptr) { + env->DeleteLocalRef(jTensorShape); + return nullptr; + } + + // Set byte order to native order + jclass byteBufferClass = env->FindClass("java/nio/ByteBuffer"); + jmethodID orderMethod = + env->GetMethodID(byteBufferClass, "order", "(Ljava/nio/ByteOrder;)Ljava/nio/ByteBuffer;"); + jclass byteOrderClass = env->FindClass("java/nio/ByteOrder"); + jmethodID nativeOrderMethod = + env->GetStaticMethodID(byteOrderClass, "nativeOrder", "()Ljava/nio/ByteOrder;"); + jobject nativeOrder = env->CallStaticObjectMethod(byteOrderClass, nativeOrderMethod); + env->CallObjectMethod(jTensorBuffer, orderMethod, nativeOrder); + + env->DeleteLocalRef(byteBufferClass); + env->DeleteLocalRef(byteOrderClass); + env->DeleteLocalRef(nativeOrder); + + // Call nativeNewTensor static method (pass 0 for nativeHandle since we don't need it) + jobject result = env->CallStaticObjectMethod( + g_jni_cache.tensor_class, + g_jni_cache.tensor_nativeNewTensor, + jTensorBuffer, + jTensorShape, + jdtype, + static_cast(0)); + + env->DeleteLocalRef(jTensorBuffer); + env->DeleteLocalRef(jTensorShape); + + return result; +} + +// Helper to create native TensorPtr from Java Tensor +TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) { + g_jni_cache.init(env); + + jint jdtype = env->CallIntMethod(jtensor, g_jni_cache.tensor_dtypeJniCode); + + jlongArray jshape = + static_cast(env->GetObjectField(jtensor, g_jni_cache.tensor_shape)); + + jobject jbuffer = env->CallObjectMethod(jtensor, g_jni_cache.tensor_getRawDataBuffer); + + jsize rank = env->GetArrayLength(jshape); + + std::vector shapeArr(rank); + env->GetLongArrayRegion(jshape, 0, rank, shapeArr.data()); + + std::vector shape_vec; + shape_vec.reserve(rank); - jint load_method(facebook::jni::alias_ref methodName) { - return static_cast(module_->load_method(methodName->toStdString())); + int64_t numel = 1; + for (int i = 0; i < rank; i++) { + shape_vec.push_back(shapeArr[i]); + } + for (int i = rank - 1; i >= 0; --i) { + numel *= shapeArr[i]; } - facebook::jni::local_ref> execute_method( - std::string method, - facebook::jni::alias_ref< - facebook::jni::JArrayClass::javaobject> - jinputs) { - // If no inputs is given, it will run with sample inputs (ones) - if (jinputs->size() == 0) { - auto result = module_->load_method(method); - if (result != Error::Ok) { - // Format hex string - std::stringstream ss; - ss << "Cannot get method names [Native Error: 0x" << std::hex - << std::uppercase << static_cast(result) << "]"; + if (java_dtype_to_scalar_type.count(jdtype) == 0) { + std::stringstream ss; + ss << "Unknown Tensor jdtype: [" << jdtype << "]"; + jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); + return nullptr; + } - jni_helper::throwExecutorchException( - static_cast(result), ss.str()); - return {}; - } - auto&& underlying_method = module_->methods_[method].method; - auto&& buf = prepare_input_tensors(*underlying_method); - result = underlying_method->execute(); - if (result != Error::Ok) { - jni_helper::throwExecutorchException( - static_cast(result), - "Execution failed for method: " + method); - return {}; - } - facebook::jni::local_ref> jresult = - facebook::jni::JArrayClass::newArray( - underlying_method->outputs_size()); - - for (int i = 0; i < underlying_method->outputs_size(); i++) { - auto jevalue = - JEValue::newJEValueFromEValue(underlying_method->get_output(i)); - jresult->setElement(i, *jevalue); - } - return jresult; - } + ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); + const jlong dataCapacity = env->GetDirectBufferCapacity(jbuffer); + if (dataCapacity < 0) { + std::stringstream ss; + ss << "Tensor buffer is not direct or has invalid capacity"; + jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); + return nullptr; + } - std::vector evalues; - std::vector tensors; - - static const auto typeCodeField = - JEValue::javaClassStatic()->getField("mTypeCode"); - - for (int i = 0; i < jinputs->size(); i++) { - auto jevalue = jinputs->getElement(i); - const auto typeCode = jevalue->getFieldValue(typeCodeField); - if (typeCode == JEValue::kTypeCodeTensor) { - tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue)); - evalues.emplace_back(tensors.back()); - } else if (typeCode == JEValue::kTypeCodeInt) { - int64_t value = jevalue->getFieldValue(typeCodeField); - evalues.emplace_back(value); - } else if (typeCode == JEValue::kTypeCodeDouble) { - double value = jevalue->getFieldValue(typeCodeField); - evalues.emplace_back(value); - } else if (typeCode == JEValue::kTypeCodeBool) { - bool value = jevalue->getFieldValue(typeCodeField); - evalues.emplace_back(value); - } - } + const size_t elementSize = executorch::runtime::elementSize(scalar_type); + const jlong expectedElements = static_cast(numel); + const jlong expectedBytes = expectedElements * static_cast(elementSize); + const bool matchesElements = dataCapacity == expectedElements; + const bool matchesBytes = dataCapacity == expectedBytes; -#ifdef EXECUTORCH_ANDROID_PROFILING - auto start = std::chrono::high_resolution_clock::now(); - auto result = module_->execute(method, evalues); - auto end = std::chrono::high_resolution_clock::now(); - auto duration = - std::chrono::duration_cast(end - start) - .count(); - ET_LOG(Debug, "Execution time: %lld ms.", duration); + if (!matchesElements && !matchesBytes) { + std::stringstream ss; + ss << "Tensor dimensions(elements number: " << numel + << ") inconsistent with buffer capacity " << dataCapacity + << " (element size bytes: " << elementSize << ")"; + jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); + return nullptr; + } -#else - auto result = module_->execute(method, evalues); + void* data = env->GetDirectBufferAddress(jbuffer); + TensorPtr result = from_blob(data, shape_vec, scalar_type); -#endif + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); - if (!result.ok()) { - jni_helper::throwExecutorchException( - static_cast(result.error()), - "Execution failed for method: " + method); - return {}; - } + return result; +} - facebook::jni::local_ref> jresult = - facebook::jni::JArrayClass::newArray(result.get().size()); +// Helper to create Java EValue from native EValue +jobject newJEValueFromEValue(JNIEnv* env, EValue evalue) { + g_jni_cache.init(env); - for (int i = 0; i < result.get().size(); i++) { - auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]); - jresult->setElement(i, *jevalue); + if (evalue.isTensor()) { + jobject jtensor = newJTensorFromTensor(env, evalue.toTensor()); + if (jtensor == nullptr) { + return nullptr; } - return jresult; + jobject result = env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_tensor, jtensor); + env->DeleteLocalRef(jtensor); + return result; + } else if (evalue.isInt()) { + return env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_long, evalue.toInt()); + } else if (evalue.isDouble()) { + return env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_double, evalue.toDouble()); + } else if (evalue.isBool()) { + return env->CallStaticObjectMethod( + g_jni_cache.evalue_class, + g_jni_cache.evalue_from_bool, + static_cast(evalue.toBool())); + } else if (evalue.isString()) { + std::string str = + std::string(evalue.toString().begin(), evalue.toString().end()); + jstring jstr = env->NewStringUTF(str.c_str()); + jobject result = env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_string, jstr); + env->DeleteLocalRef(jstr); + return result; } - facebook::jni::local_ref> - readLogBuffer() { - return readLogBufferUtil(); + std::stringstream ss; + ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]"; + jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; +} + +// Helper to get TensorPtr from Java EValue +TensorPtr JEValueToTensorImpl(JNIEnv* env, jobject jevalue) { + g_jni_cache.init(env); + + jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode); + if (typeCode == kTypeCodeTensor) { + jobject jtensor = + env->CallObjectMethod(jevalue, g_jni_cache.evalue_toTensor); + TensorPtr result = newTensorFromJTensor(env, jtensor); + env->DeleteLocalRef(jtensor); + return result; } - static facebook::jni::local_ref> - readLogBufferStatic(facebook::jni::alias_ref) { - return readLogBufferUtil(); + std::stringstream ss; + ss << "Unknown EValue typeCode: " << typeCode; + jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; +} + +} // namespace + +extern "C" { + +JNIEXPORT jlong JNICALL +Java_org_pytorch_executorch_Module_nativeCreate( + JNIEnv* env, + jclass /* clazz */, + jstring modelPath, + jint loadMode, + jint numThreads) { + auto* native = new ExecuTorchModuleNative(env, modelPath, loadMode, numThreads); + return reinterpret_cast(native); +} + +JNIEXPORT void JNICALL +Java_org_pytorch_executorch_Module_nativeDestroy( + JNIEnv* /* env */, + jclass /* clazz */, + jlong nativeHandle) { + auto* native = reinterpret_cast(nativeHandle); + delete native; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeExecute( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle, + jstring methodName, + jobjectArray jinputs) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return nullptr; } - static facebook::jni::local_ref> - readLogBufferUtil() { -#ifdef __ANDROID__ + g_jni_cache.init(env); + + std::string method = jstring_to_string(env, methodName); + jsize inputSize = jinputs != nullptr ? env->GetArrayLength(jinputs) : 0; + + // If no inputs is given, it will run with sample inputs (ones) + if (inputSize == 0) { + auto result = native->module_->load_method(method); + if (result != Error::Ok) { + std::stringstream ss; + ss << "Cannot get method names [Native Error: 0x" << std::hex + << std::uppercase << static_cast(result) << "]"; + jni_helper::throwExecutorchException( + env, static_cast(result), ss.str()); + return nullptr; + } + auto&& underlying_method = native->module_->methods_[method].method; + auto&& buf = prepare_input_tensors(*underlying_method); + result = underlying_method->execute(); + if (result != Error::Ok) { + jni_helper::throwExecutorchException( + env, static_cast(result), "Execution failed for method: " + method); + return nullptr; + } - facebook::jni::local_ref> ret; - - access_log_buffer([&](std::vector& buffer) { - const auto size = buffer.size(); - ret = facebook::jni::JArrayClass::newArray(size); - for (auto i = 0u; i < size; i++) { - const auto& entry = buffer[i]; - // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL - // MESSAGE". - std::stringstream ss; - ss << "[" << entry.timestamp << " " << entry.function << " " - << entry.filename << ":" << entry.line << "] " - << static_cast(entry.level) << " " << entry.message; - - facebook::jni::local_ref jstr_message = - facebook::jni::make_jstring(ss.str().c_str()); - (*ret)[i] = jstr_message; + jobjectArray jresult = + env->NewObjectArray(underlying_method->outputs_size(), g_jni_cache.evalue_class, nullptr); + + for (int i = 0; i < underlying_method->outputs_size(); i++) { + jobject jevalue = newJEValueFromEValue(env, underlying_method->get_output(i)); + env->SetObjectArrayElement(jresult, i, jevalue); + if (jevalue != nullptr) { + env->DeleteLocalRef(jevalue); } - }); + } + return jresult; + } - return ret; -#else - return facebook::jni::JArrayClass::newArray(0); -#endif + std::vector evalues; + std::vector tensors; + + for (int i = 0; i < inputSize; i++) { + jobject jevalue = env->GetObjectArrayElement(jinputs, i); + jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode); + + if (typeCode == kTypeCodeTensor) { + tensors.emplace_back(JEValueToTensorImpl(env, jevalue)); + evalues.emplace_back(tensors.back()); + } else if (typeCode == kTypeCodeInt) { + jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData); + jclass longClass = env->FindClass("java/lang/Long"); + jmethodID longValue = env->GetMethodID(longClass, "longValue", "()J"); + jlong value = env->CallLongMethod(mData, longValue); + evalues.emplace_back(static_cast(value)); + env->DeleteLocalRef(mData); + env->DeleteLocalRef(longClass); + } else if (typeCode == kTypeCodeDouble) { + jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData); + jclass doubleClass = env->FindClass("java/lang/Double"); + jmethodID doubleValue = env->GetMethodID(doubleClass, "doubleValue", "()D"); + jdouble value = env->CallDoubleMethod(mData, doubleValue); + evalues.emplace_back(static_cast(value)); + env->DeleteLocalRef(mData); + env->DeleteLocalRef(doubleClass); + } else if (typeCode == kTypeCodeBool) { + jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData); + jclass boolClass = env->FindClass("java/lang/Boolean"); + jmethodID boolValue = env->GetMethodID(boolClass, "booleanValue", "()Z"); + jboolean value = env->CallBooleanMethod(mData, boolValue); + evalues.emplace_back(static_cast(value)); + env->DeleteLocalRef(mData); + env->DeleteLocalRef(boolClass); + } + env->DeleteLocalRef(jevalue); } - jboolean etdump() { #ifdef EXECUTORCH_ANDROID_PROFILING - executorch::etdump::ETDumpGen* etdumpgen = - (executorch::etdump::ETDumpGen*)module_->event_tracer(); - auto etdump_data = etdumpgen->get_etdump_data(); - - if (etdump_data.buf != nullptr && etdump_data.size > 0) { - int etdump_file = - open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644); - if (etdump_file == -1) { - ET_LOG(Error, "Cannot create result.etdump error: %d", errno); - return false; - } - ssize_t bytes_written = - write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size); - if (bytes_written == -1) { - ET_LOG(Error, "Cannot write result.etdump error: %d", errno); - return false; - } else { - ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written); - } - close(etdump_file); - free(etdump_data.buf); - return true; - } else { - ET_LOG(Error, "No ETDump data available!"); - } + auto start = std::chrono::high_resolution_clock::now(); + auto result = native->module_->execute(method, evalues); + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start).count(); + ET_LOG(Debug, "Execution time: %lld ms.", duration); +#else + auto result = native->module_->execute(method, evalues); #endif - return false; + + if (!result.ok()) { + jni_helper::throwExecutorchException( + env, + static_cast(result.error()), + "Execution failed for method: " + method); + return nullptr; } - facebook::jni::local_ref> getMethods() { - const auto& names_result = module_->method_names(); - if (!names_result.ok()) { - // Format hex string - std::stringstream ss; - ss << "Cannot get load module [Native Error: 0x" << std::hex - << std::uppercase << static_cast(names_result.error()) - << "]"; + jobjectArray jresult = + env->NewObjectArray(result.get().size(), g_jni_cache.evalue_class, nullptr); - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str()); - return {}; - } - const auto& methods = names_result.get(); - facebook::jni::local_ref> ret = - facebook::jni::JArrayClass::newArray(methods.size()); - int i = 0; - for (auto s : methods) { - facebook::jni::local_ref method_name = - facebook::jni::make_jstring(s.c_str()); - (*ret)[i] = method_name; - i++; + for (size_t i = 0; i < result.get().size(); i++) { + jobject jevalue = newJEValueFromEValue(env, result.get()[i]); + env->SetObjectArrayElement(jresult, i, jevalue); + if (jevalue != nullptr) { + env->DeleteLocalRef(jevalue); } - return ret; } + return jresult; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_Module_nativeLoadMethod( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle, + jstring methodName) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return -1; + } + std::string method = jstring_to_string(env, methodName); + return static_cast(native->module_->load_method(method)); +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeGetMethods( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return nullptr; + } + + const auto& names_result = native->module_->method_names(); + if (!names_result.ok()) { + std::stringstream ss; + ss << "Cannot get load module [Native Error: 0x" << std::hex + << std::uppercase << static_cast(names_result.error()) << "]"; + jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str()); + return nullptr; + } + + const auto& methods = names_result.get(); + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = env->NewObjectArray(methods.size(), stringClass, nullptr); - facebook::jni::local_ref> getUsedBackends( - facebook::jni::alias_ref methodName) { - auto methodMeta = module_->method_meta(methodName->toStdString()).get(); - std::unordered_set backends; - for (auto i = 0; i < methodMeta.num_backends(); i++) { - backends.insert(methodMeta.get_backend_name(i).get()); + int i = 0; + for (auto s : methods) { + jstring method_name = env->NewStringUTF(s.c_str()); + env->SetObjectArrayElement(ret, i, method_name); + env->DeleteLocalRef(method_name); + i++; + } + env->DeleteLocalRef(stringClass); + return ret; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeGetUsedBackends( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle, + jstring methodName) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return nullptr; + } + + std::string method = jstring_to_string(env, methodName); + auto methodMeta = native->module_->method_meta(method).get(); + std::unordered_set backends; + for (auto i = 0; i < methodMeta.num_backends(); i++) { + backends.insert(methodMeta.get_backend_name(i).get()); + } + + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = env->NewObjectArray(backends.size(), stringClass, nullptr); + + int i = 0; + for (auto s : backends) { + jstring backend_name = env->NewStringUTF(s.c_str()); + env->SetObjectArrayElement(ret, i, backend_name); + env->DeleteLocalRef(backend_name); + i++; + } + env->DeleteLocalRef(stringClass); + return ret; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeReadLogBuffer( + JNIEnv* env, + jclass /* clazz */, + jlong /* nativeHandle */) { +#ifdef __ANDROID__ + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = nullptr; + + access_log_buffer([&](std::vector& buffer) { + const auto size = buffer.size(); + ret = env->NewObjectArray(size, stringClass, nullptr); + for (auto i = 0u; i < size; i++) { + const auto& entry = buffer[i]; + std::stringstream ss; + ss << "[" << entry.timestamp << " " << entry.function << " " + << entry.filename << ":" << entry.line << "] " + << static_cast(entry.level) << " " << entry.message; + jstring jstr_message = env->NewStringUTF(ss.str().c_str()); + env->SetObjectArrayElement(ret, i, jstr_message); + env->DeleteLocalRef(jstr_message); } + }); + + env->DeleteLocalRef(stringClass); + return ret; +#else + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = env->NewObjectArray(0, stringClass, nullptr); + env->DeleteLocalRef(stringClass); + return ret; +#endif +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic( + JNIEnv* env, + jclass clazz) { + return Java_org_pytorch_executorch_Module_nativeReadLogBuffer(env, clazz, 0); +} + +JNIEXPORT jboolean JNICALL +Java_org_pytorch_executorch_Module_nativeEtdump( + JNIEnv* /* env */, + jclass /* clazz */, + jlong nativeHandle) { +#ifdef EXECUTORCH_ANDROID_PROFILING + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return JNI_FALSE; + } + + executorch::etdump::ETDumpGen* etdumpgen = + (executorch::etdump::ETDumpGen*)native->module_->event_tracer(); + auto etdump_data = etdumpgen->get_etdump_data(); - facebook::jni::local_ref> ret = - facebook::jni::JArrayClass::newArray(backends.size()); - int i = 0; - for (auto s : backends) { - facebook::jni::local_ref backend_name = - facebook::jni::make_jstring(s.c_str()); - (*ret)[i] = backend_name; - i++; + if (etdump_data.buf != nullptr && etdump_data.size > 0) { + int etdump_file = + open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644); + if (etdump_file == -1) { + ET_LOG(Error, "Cannot create result.etdump error: %d", errno); + return JNI_FALSE; } - return ret; - } - - static void registerNatives() { - registerHybrid({ - makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), - makeNativeMethod("executeNative", ExecuTorchJni::execute), - makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method), - makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), - makeNativeMethod( - "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic), - makeNativeMethod("etdump", ExecuTorchJni::etdump), - makeNativeMethod("getMethods", ExecuTorchJni::getMethods), - makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), - }); + ssize_t bytes_written = + write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size); + if (bytes_written == -1) { + ET_LOG(Error, "Cannot write result.etdump error: %d", errno); + return JNI_FALSE; + } else { + ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written); + } + close(etdump_file); + free(etdump_data.buf); + return JNI_TRUE; + } else { + ET_LOG(Error, "No ETDump data available!"); } -}; -} // namespace executorch::extension +#endif + return JNI_FALSE; +} + +} // extern "C" #ifdef EXECUTORCH_BUILD_LLAMA_JNI extern void register_natives_for_llm(JNIEnv* env); @@ -540,22 +698,72 @@ extern void register_natives_for_llm(JNIEnv* env); // No op if we don't build LLM void register_natives_for_llm(JNIEnv* /* env */) {} #endif -extern void register_natives_for_runtime(); #ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING -extern void register_natives_for_training(); +extern void register_natives_for_training(JNIEnv* env); #else // No op if we don't build training JNI -void register_natives_for_training() {} +void register_natives_for_training(JNIEnv* /* env */) {} #endif +void register_natives_for_runtime(JNIEnv* env); + +void register_natives_for_module(JNIEnv* env) { + jclass module_class = env->FindClass("org/pytorch/executorch/Module"); + if (module_class == nullptr) { + ET_LOG(Error, "Failed to find Module class"); + env->ExceptionClear(); + return; + } + + // clang-format off + static const JNINativeMethod methods[] = { + {"nativeCreate", "(Ljava/lang/String;II)J", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeCreate)}, + {"nativeDestroy", "(J)V", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeDestroy)}, + {"nativeExecute", + "(JLjava/lang/String;[Lorg/pytorch/executorch/EValue;)[Lorg/pytorch/executorch/EValue;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeExecute)}, + {"nativeLoadMethod", "(JLjava/lang/String;)I", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeLoadMethod)}, + {"nativeGetMethods", "(J)[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetMethods)}, + {"nativeGetUsedBackends", "(JLjava/lang/String;)[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetUsedBackends)}, + {"nativeReadLogBuffer", "(J)[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBuffer)}, + {"nativeReadLogBufferStatic", "()[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic)}, + {"nativeEtdump", "(J)Z", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeEtdump)}, + }; + // clang-format on + + int num_methods = sizeof(methods) / sizeof(methods[0]); + int result = env->RegisterNatives(module_class, methods, num_methods); + if (result != JNI_OK) { + ET_LOG(Error, "Failed to register native methods for Module"); + } + + env->DeleteLocalRef(module_class); +} + JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { - return facebook::jni::initialize(vm, [] { - executorch::extension::ExecuTorchJni::registerNatives(); - // Get JNIEnv for pure JNI registration in LLM - JNIEnv* env = facebook::jni::Environment::current(); - register_natives_for_llm(env); - register_natives_for_runtime(); - register_natives_for_training(); - }); + g_jvm = vm; + JNIEnv* env = nullptr; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + return JNI_ERR; + } + + // Initialize the JNI cache + g_jni_cache.init(env); + + // Register native methods + register_natives_for_module(env); + register_natives_for_llm(env); + register_natives_for_runtime(env); + register_natives_for_training(env); + + return JNI_VERSION_1_6; } diff --git a/extension/android/jni/jni_layer_runtime.cpp b/extension/android/jni/jni_layer_runtime.cpp index 890e1d0fad9..32e7866353a 100644 --- a/extension/android/jni/jni_layer_runtime.cpp +++ b/extension/android/jni/jni_layer_runtime.cpp @@ -6,67 +6,90 @@ * LICENSE file in the root directory of this source tree. */ -#include #include #include #include +#include namespace executorch_jni { namespace runtime = ::executorch::ET_RUNTIME_NAMESPACE; -class AndroidRuntimeJni : public facebook::jni::JavaClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/ExecuTorchRuntime;"; - - static void registerNatives() { - javaClassStatic()->registerNatives({ - makeNativeMethod( - "getRegisteredOps", AndroidRuntimeJni::getRegisteredOps), - makeNativeMethod( - "getRegisteredBackends", AndroidRuntimeJni::getRegisteredBackends), - }); - } +} // namespace executorch_jni - // Returns a string array of all registered ops - static facebook::jni::local_ref> - getRegisteredOps(facebook::jni::alias_ref) { - auto kernels = runtime::get_registered_kernels(); - auto result = facebook::jni::JArrayClass::newArray(kernels.size()); +extern "C" { - for (size_t i = 0; i < kernels.size(); ++i) { - auto op = facebook::jni::make_jstring(kernels[i].name_); - result->setElement(i, op.get()); - } +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredOps( + JNIEnv* env, + jclass /* clazz */) { + auto kernels = executorch_jni::runtime::get_registered_kernels(); + + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray result = env->NewObjectArray(kernels.size(), stringClass, nullptr); - return result; + for (size_t i = 0; i < kernels.size(); ++i) { + jstring op = env->NewStringUTF(kernels[i].name_); + env->SetObjectArrayElement(result, i, op); + env->DeleteLocalRef(op); } - // Returns a string array of all registered backends - static facebook::jni::local_ref> - getRegisteredBackends(facebook::jni::alias_ref) { - int num_backends = runtime::get_num_registered_backends(); - auto result = facebook::jni::JArrayClass::newArray(num_backends); + env->DeleteLocalRef(stringClass); + return result; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredBackends( + JNIEnv* env, + jclass /* clazz */) { + int num_backends = executorch_jni::runtime::get_num_registered_backends(); - for (int i = 0; i < num_backends; ++i) { - auto name_result = runtime::get_backend_name(i); - const char* name = ""; + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray result = env->NewObjectArray(num_backends, stringClass, nullptr); - if (name_result.ok()) { - name = *name_result; - } + for (int i = 0; i < num_backends; ++i) { + auto name_result = executorch_jni::runtime::get_backend_name(i); + const char* name = ""; - auto backend_str = facebook::jni::make_jstring(name); - result->setElement(i, backend_str.get()); + if (name_result.ok()) { + name = *name_result; } - return result; + jstring backend_str = env->NewStringUTF(name); + env->SetObjectArrayElement(result, i, backend_str); + env->DeleteLocalRef(backend_str); } -}; -} // namespace executorch_jni + env->DeleteLocalRef(stringClass); + return result; +} + +} // extern "C" + +void register_natives_for_runtime(JNIEnv* env) { + jclass runtime_class = env->FindClass("org/pytorch/executorch/ExecuTorchRuntime"); + if (runtime_class == nullptr) { + ET_LOG(Error, "Failed to find ExecuTorchRuntime class"); + env->ExceptionClear(); + return; + } + + // clang-format off + static const JNINativeMethod methods[] = { + {"nativeGetRegisteredOps", "()[Ljava/lang/String;", + reinterpret_cast( + Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredOps)}, + {"nativeGetRegisteredBackends", "()[Ljava/lang/String;", + reinterpret_cast( + Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredBackends)}, + }; + // clang-format on + + int num_methods = sizeof(methods) / sizeof(methods[0]); + int result = env->RegisterNatives(runtime_class, methods, num_methods); + if (result != JNI_OK) { + ET_LOG(Error, "Failed to register native methods for ExecuTorchRuntime"); + } -void register_natives_for_runtime() { - executorch_jni::AndroidRuntimeJni::registerNatives(); + env->DeleteLocalRef(runtime_class); } diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp index 5a5e9f24d2f..26308f86731 100644 --- a/extension/android/jni/jni_layer_training.cpp +++ b/extension/android/jni/jni_layer_training.cpp @@ -345,7 +345,7 @@ class SGDHybrid : public facebook::jni::HybridClass { } // namespace executorch::extension // Function to register training module natives -void register_natives_for_training() { +void register_natives_for_training(JNIEnv* /* env */) { executorch::extension::ExecuTorchTrainingJni::registerNatives(); executorch::extension::SGDHybrid::registerNatives(); }; From 1a810d991ec73977e5b33eaf26e3fddacf78e90e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:03:37 +0000 Subject: [PATCH 05/11] Add null checks for native handle destruction Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- extension/android/jni/jni_layer.cpp | 6 ++++-- extension/android/jni/jni_layer_llama.cpp | 8 +++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 8645e0dd397..5fc89a658bf 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -404,8 +404,10 @@ Java_org_pytorch_executorch_Module_nativeDestroy( JNIEnv* /* env */, jclass /* clazz */, jlong nativeHandle) { - auto* native = reinterpret_cast(nativeHandle); - delete native; + if (nativeHandle != 0) { + auto* native = reinterpret_cast(nativeHandle); + delete native; + } } JNIEXPORT jobjectArray JNICALL diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index c6844552523..e26b40f4f4b 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -331,9 +331,11 @@ Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy( JNIEnv* /* env */, jobject /* this */, jlong native_handle) { - auto* native = - reinterpret_cast(native_handle); - delete native; + if (native_handle != 0) { + auto* native = + reinterpret_cast(native_handle); + delete native; + } } JNIEXPORT jint JNICALL From 89e7baa8a9a44a14e8972109b23b9d9798cce79d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 8 Jan 2026 02:42:24 +0000 Subject: [PATCH 06/11] Fix namespace for jni_helper in jni_layer.cpp Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- extension/android/jni/jni_layer.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 5fc89a658bf..0a8f38890ea 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -192,7 +192,7 @@ jobject newJTensorFromTensor(JNIEnv* env, const executorch::aten::Tensor& tensor if (scalar_type_to_java_dtype.count(scalarType) == 0) { std::stringstream ss; ss << "executorch::aten::Tensor scalar type is not supported on java side"; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(Error::InvalidArgument), ss.str().c_str()); return nullptr; } @@ -277,7 +277,7 @@ TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) { if (java_dtype_to_scalar_type.count(jdtype) == 0) { std::stringstream ss; ss << "Unknown Tensor jdtype: [" << jdtype << "]"; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(Error::InvalidArgument), ss.str().c_str()); env->DeleteLocalRef(jshape); env->DeleteLocalRef(jbuffer); @@ -289,7 +289,7 @@ TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) { if (dataCapacity < 0) { std::stringstream ss; ss << "Tensor buffer is not direct or has invalid capacity"; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(Error::InvalidArgument), ss.str().c_str()); env->DeleteLocalRef(jshape); env->DeleteLocalRef(jbuffer); @@ -307,7 +307,7 @@ TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) { ss << "Tensor dimensions(elements number: " << numel << ") inconsistent with buffer capacity " << dataCapacity << " (element size bytes: " << elementSize << ")"; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(Error::InvalidArgument), ss.str().c_str()); env->DeleteLocalRef(jshape); env->DeleteLocalRef(jbuffer); @@ -359,7 +359,7 @@ jobject newJEValueFromEValue(JNIEnv* env, EValue evalue) { std::stringstream ss; ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]"; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(Error::InvalidArgument), ss.str().c_str()); return nullptr; } @@ -379,7 +379,7 @@ TensorPtr JEValueToTensorImpl(JNIEnv* env, jobject jevalue) { std::stringstream ss; ss << "Unknown EValue typeCode: " << typeCode; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(Error::InvalidArgument), ss.str().c_str()); return nullptr; } @@ -434,7 +434,7 @@ Java_org_pytorch_executorch_Module_nativeExecute( std::stringstream ss; ss << "Cannot get method names [Native Error: 0x" << std::hex << std::uppercase << static_cast(result) << "]"; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(result), ss.str()); return nullptr; } @@ -442,7 +442,7 @@ Java_org_pytorch_executorch_Module_nativeExecute( auto&& buf = prepare_input_tensors(*underlying_method); result = underlying_method->execute(); if (result != Error::Ok) { - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(result), "Execution failed for method: " + method); return nullptr; } @@ -510,7 +510,7 @@ Java_org_pytorch_executorch_Module_nativeExecute( #endif if (!result.ok()) { - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(result.error()), "Execution failed for method: " + method); @@ -559,7 +559,7 @@ Java_org_pytorch_executorch_Module_nativeGetMethods( std::stringstream ss; ss << "Cannot get load module [Native Error: 0x" << std::hex << std::uppercase << static_cast(names_result.error()) << "]"; - jni_helper::throwExecutorchException( + executorch::jni_helper::throwExecutorchException( env, static_cast(Error::InvalidArgument), ss.str()); return nullptr; } From 7128cc0843d9b2effbb0e0096727ecba1d8d95c2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 8 Jan 2026 03:14:13 +0000 Subject: [PATCH 07/11] Fix protected member access by using friend class ExecuTorchJni Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- extension/android/jni/jni_layer.cpp | 44 +++++++++++++++++------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 0a8f38890ea..f8f460afa98 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -44,13 +44,8 @@ using namespace executorch::extension; using namespace torch::executor; -namespace { - -// Global JavaVM pointer for obtaining JNIEnv in callbacks -JavaVM* g_jvm = nullptr; - -// Helper to convert jstring to std::string -std::string jstring_to_string(JNIEnv* env, jstring jstr) { +// Helper to convert jstring to std::string (defined outside namespace for broad access) +static std::string jstring_to_string(JNIEnv* env, jstring jstr) { if (jstr == nullptr) { return ""; } @@ -63,6 +58,11 @@ std::string jstring_to_string(JNIEnv* env, jstring jstr) { return result; } +namespace { + +// Global JavaVM pointer for obtaining JNIEnv in callbacks +JavaVM* g_jvm = nullptr; + // EValue type codes (must match Java EValue class) constexpr int kTypeCodeNone = 0; constexpr int kTypeCodeTensor = 1; @@ -137,18 +137,22 @@ struct JniCache { evalue_mData = env->GetFieldID(evalue_class, "mData", "Ljava/lang/Object;"); } - initialized = true; + initialized = true; } }; JniCache g_jni_cache; -// Native module handle class -class ExecuTorchModuleNative { +} // anonymous namespace + +namespace executorch::extension { + +// Native module handle class - named ExecuTorchJni to match friend declaration in Module +class ExecuTorchJni { public: std::unique_ptr module_; - ExecuTorchModuleNative( + ExecuTorchJni( JNIEnv* env, jstring modelPath, jint loadMode, @@ -184,6 +188,10 @@ class ExecuTorchModuleNative { } }; +} // namespace executorch::extension + +namespace { + // Helper to create Java Tensor from native tensor jobject newJTensorFromTensor(JNIEnv* env, const executorch::aten::Tensor& tensor) { g_jni_cache.init(env); @@ -395,7 +403,7 @@ Java_org_pytorch_executorch_Module_nativeCreate( jstring modelPath, jint loadMode, jint numThreads) { - auto* native = new ExecuTorchModuleNative(env, modelPath, loadMode, numThreads); + auto* native = new executorch::extension::ExecuTorchJni(env, modelPath, loadMode, numThreads); return reinterpret_cast(native); } @@ -405,7 +413,7 @@ Java_org_pytorch_executorch_Module_nativeDestroy( jclass /* clazz */, jlong nativeHandle) { if (nativeHandle != 0) { - auto* native = reinterpret_cast(nativeHandle); + auto* native = reinterpret_cast(nativeHandle); delete native; } } @@ -417,7 +425,7 @@ Java_org_pytorch_executorch_Module_nativeExecute( jlong nativeHandle, jstring methodName, jobjectArray jinputs) { - auto* native = reinterpret_cast(nativeHandle); + auto* native = reinterpret_cast(nativeHandle); if (native == nullptr) { return nullptr; } @@ -536,7 +544,7 @@ Java_org_pytorch_executorch_Module_nativeLoadMethod( jclass /* clazz */, jlong nativeHandle, jstring methodName) { - auto* native = reinterpret_cast(nativeHandle); + auto* native = reinterpret_cast(nativeHandle); if (native == nullptr) { return -1; } @@ -549,7 +557,7 @@ Java_org_pytorch_executorch_Module_nativeGetMethods( JNIEnv* env, jclass /* clazz */, jlong nativeHandle) { - auto* native = reinterpret_cast(nativeHandle); + auto* native = reinterpret_cast(nativeHandle); if (native == nullptr) { return nullptr; } @@ -585,7 +593,7 @@ Java_org_pytorch_executorch_Module_nativeGetUsedBackends( jclass /* clazz */, jlong nativeHandle, jstring methodName) { - auto* native = reinterpret_cast(nativeHandle); + auto* native = reinterpret_cast(nativeHandle); if (native == nullptr) { return nullptr; } @@ -658,7 +666,7 @@ Java_org_pytorch_executorch_Module_nativeEtdump( jclass /* clazz */, jlong nativeHandle) { #ifdef EXECUTORCH_ANDROID_PROFILING - auto* native = reinterpret_cast(nativeHandle); + auto* native = reinterpret_cast(nativeHandle); if (native == nullptr) { return JNI_FALSE; } From ae48cde0f00754bb9c0b5b535d71081b586ac473 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 8 Jan 2026 09:10:27 +0000 Subject: [PATCH 08/11] Fix protected methods_ access via friend class member function Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- extension/android/jni/jni_layer.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index f8f460afa98..636e4de0b36 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -186,6 +186,15 @@ class ExecuTorchJni { } #endif } + + // Access protected methods_ member (friend class privilege) + Method* get_method(const std::string& method_name) { + auto it = module_->methods_.find(method_name); + if (it != module_->methods_.end()) { + return it->second.method.get(); + } + return nullptr; + } }; } // namespace executorch::extension @@ -446,7 +455,12 @@ Java_org_pytorch_executorch_Module_nativeExecute( env, static_cast(result), ss.str()); return nullptr; } - auto&& underlying_method = native->module_->methods_[method].method; + auto* underlying_method = native->get_method(method); + if (underlying_method == nullptr) { + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), "Method not found: " + method); + return nullptr; + } auto&& buf = prepare_input_tensors(*underlying_method); result = underlying_method->execute(); if (result != Error::Ok) { From fa6f82ac903eda6be93fb7e050d09507562a8c9b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 8 Jan 2026 18:29:12 +0000 Subject: [PATCH 09/11] Add full TensorHybrid/JEValue implementations for training module Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- extension/android/jni/jni_layer_training.cpp | 149 ++++++++++++++++++- 1 file changed, 144 insertions(+), 5 deletions(-) diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp index 26308f86731..03654f5d4ae 100644 --- a/extension/android/jni/jni_layer_training.cpp +++ b/extension/android/jni/jni_layer_training.cpp @@ -12,10 +12,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -28,19 +30,97 @@ using namespace torch::executor; namespace executorch::extension { -// Forward declarations from jni_layer.cpp +// Full implementation of TensorHybrid for training module (fbjni-based) class TensorHybrid : public facebook::jni::HybridClass { public: constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/Tensor;"; static facebook::jni::local_ref - newJTensorFromTensor(const executorch::aten::Tensor& tensor); + newJTensorFromTensor(const executorch::aten::Tensor& tensor) { + const auto scalarType = tensor.scalar_type(); + if (scalar_type_to_java_dtype.count(scalarType) == 0) { + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "executorch::aten::Tensor scalar type %d is not supported on java side", + static_cast(scalarType)); + } + int jdtype = scalar_type_to_java_dtype.at(scalarType); + + const auto& tensor_shape = tensor.sizes(); + std::vector tensor_shape_vec; + for (const auto& s : tensor_shape) { + tensor_shape_vec.push_back(s); + } + facebook::jni::local_ref jTensorShape = + facebook::jni::make_long_array(tensor_shape_vec.size()); + jTensorShape->setRegion( + 0, tensor_shape_vec.size(), tensor_shape_vec.data()); + + facebook::jni::local_ref jTensorBuffer = + facebook::jni::JByteBuffer::wrapBytes( + (uint8_t*)tensor.const_data_ptr(), tensor.nbytes()); + jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder()); + + static auto cls = TensorHybrid::javaClassStatic(); + static const auto jMethodNewTensor = + cls->getStaticMethod( + facebook::jni::local_ref, + facebook::jni::local_ref, + jint, + facebook::jni::local_ref)>("nativeNewTensor"); + return jMethodNewTensor( + cls, std::move(jTensorBuffer), std::move(jTensorShape), jdtype, nullptr); + } static TensorPtr newTensorFromJTensor( - facebook::jni::alias_ref jtensor); + facebook::jni::alias_ref jtensor) { + static const auto dtypeMethod = + TensorHybrid::javaClassStatic()->getMethod("dtypeJniCode"); + jint jdtype = dtypeMethod(jtensor); + + static auto shapeField = + TensorHybrid::javaClassStatic()->getField("shape"); + auto jshape = jtensor->getFieldValue(shapeField); + + static const auto dataBufferMethod = + TensorHybrid::javaClassStatic() + ->getMethod()>( + "getRawDataBuffer"); + facebook::jni::local_ref jbuffer = + dataBufferMethod(jtensor); + + const auto rank = jshape->size(); + + std::vector shapeArr(rank); + jshape->getRegion(0, rank, shapeArr.data()); + + std::vector sizes_vec; + sizes_vec.reserve(rank); + + int64_t numel = 1; + for (int i = 0; i < rank; i++) { + sizes_vec.push_back(shapeArr[i]); + } + for (int i = rank - 1; i >= 0; --i) { + numel *= shapeArr[i]; + } + + JNIEnv* jni = jbuffer->getPlainJniEnv(); + void* dataPtr = jni->GetDirectBufferAddress(jbuffer.get()); + if (java_dtype_to_scalar_type.count(jdtype) == 0) { + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "Unknown Tensor jdtype: %d", + jdtype); + } + + ScalarType scalarType = java_dtype_to_scalar_type.at(jdtype); + return from_blob(dataPtr, sizes_vec, scalarType); + } }; +// Full implementation of JEValue for training module (fbjni-based) class JEValue : public facebook::jni::JavaClass { public: constexpr static const char* kJavaDescriptor = @@ -53,10 +133,69 @@ class JEValue : public facebook::jni::JavaClass { constexpr static int kTypeCodeBool = 5; static facebook::jni::local_ref newJEValueFromEValue( - runtime::EValue evalue); + runtime::EValue evalue) { + if (evalue.isTensor()) { + static auto jMethodTensor = + JEValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::local_ref)>("from"); + return jMethodTensor( + JEValue::javaClassStatic(), + TensorHybrid::newJTensorFromTensor(evalue.toTensor())); + } else if (evalue.isInt()) { + static auto jMethodInt = + JEValue::javaClassStatic() + ->getStaticMethod(jlong)>( + "from"); + return jMethodInt(JEValue::javaClassStatic(), evalue.toInt()); + } else if (evalue.isDouble()) { + static auto jMethodDouble = + JEValue::javaClassStatic() + ->getStaticMethod(jdouble)>( + "from"); + return jMethodDouble(JEValue::javaClassStatic(), evalue.toDouble()); + } else if (evalue.isBool()) { + static auto jMethodBool = + JEValue::javaClassStatic() + ->getStaticMethod(jboolean)>( + "from"); + return jMethodBool(JEValue::javaClassStatic(), evalue.toBool()); + } else if (evalue.isString()) { + static auto jMethodStr = + JEValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::local_ref)>("from"); + std::string str = + std::string(evalue.toString().begin(), evalue.toString().end()); + return jMethodStr( + JEValue::javaClassStatic(), facebook::jni::make_jstring(str)); + } + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "Unknown EValue type: %d", + static_cast(evalue.tag)); + return nullptr; + } static TensorPtr JEValueToTensorImpl( - facebook::jni::alias_ref JEValue); + facebook::jni::alias_ref jevalue) { + static const auto typeCodeField = + JEValue::javaClassStatic()->getField("mTypeCode"); + const auto typeCode = jevalue->getFieldValue(typeCodeField); + if (typeCode == JEValue::kTypeCodeTensor) { + static const auto jMethodGetTensor = + JEValue::javaClassStatic() + ->getMethod()>( + "toTensor"); + auto tensor = jMethodGetTensor(jevalue); + return TensorHybrid::newTensorFromJTensor(tensor); + } + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "Unknown EValue typeCode: %d", + typeCode); + return nullptr; + } }; class ExecuTorchTrainingJni From e18fe7b5ac837f6859df3fafde40ad329e2c04e6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 9 Jan 2026 22:30:45 +0000 Subject: [PATCH 10/11] Fix JNIEnv retrieval in jni_layer_training.cpp Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- extension/android/jni/jni_layer_training.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp index 03654f5d4ae..0641013a993 100644 --- a/extension/android/jni/jni_layer_training.cpp +++ b/extension/android/jni/jni_layer_training.cpp @@ -106,7 +106,7 @@ class TensorHybrid : public facebook::jni::HybridClass { numel *= shapeArr[i]; } - JNIEnv* jni = jbuffer->getPlainJniEnv(); + JNIEnv* jni = facebook::jni::Environment::current(); void* dataPtr = jni->GetDirectBufferAddress(jbuffer.get()); if (java_dtype_to_scalar_type.count(jdtype) == 0) { facebook::jni::throwNewJavaException( From 0485e18857b37febf8e1d28438574ef22be4b779 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 9 Jan 2026 23:15:27 +0000 Subject: [PATCH 11/11] Revert jni_layer_llama.cpp and LlmModule.java to split into separate PR Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com> --- .../executorch/extension/llm/LlmModule.java | 83 +- extension/android/jni/jni_layer.cpp | 6 +- extension/android/jni/jni_layer_llama.cpp | 850 ++++++------------ 3 files changed, 320 insertions(+), 619 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index 54494979766..5e080e0c369 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -8,6 +8,8 @@ package org.pytorch.executorch.extension.llm; +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; import java.io.File; import java.util.List; import org.pytorch.executorch.ExecuTorchRuntime; @@ -26,19 +28,18 @@ public class LlmModule { public static final int MODEL_TYPE_TEXT_VISION = 2; public static final int MODEL_TYPE_MULTIMODAL = 2; - private long mNativeHandle; + private final HybridData mHybridData; private static final int DEFAULT_SEQ_LEN = 128; private static final boolean DEFAULT_ECHO = true; - private static native long nativeCreate( + @DoNotStrip + private static native HybridData initHybrid( int modelType, String modulePath, String tokenizerPath, float temperature, List dataFiles); - private static native void nativeDestroy(long nativeHandle); - /** * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and * dataFiles. @@ -60,7 +61,7 @@ public LlmModule( throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); } - mNativeHandle = nativeCreate(modelType, modulePath, tokenizerPath, temperature, dataFiles); + mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles); } /** @@ -106,16 +107,7 @@ public LlmModule(LlmModuleConfig config) { } public void resetNative() { - if (mNativeHandle != 0) { - nativeDestroy(mNativeHandle); - mNativeHandle = 0; - } - } - - @Override - protected void finalize() throws Throwable { - resetNative(); - super.finalize(); + mHybridData.resetNative(); } /** @@ -158,12 +150,7 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - return nativeGenerate(mNativeHandle, prompt, seqLen, llmCallback, echo); - } - - private static native int nativeGenerate( - long nativeHandle, String prompt, int seqLen, LlmCallback llmCallback, boolean echo); + public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo); /** * Start generating tokens from the module. @@ -219,15 +206,14 @@ public int generate( */ @Experimental public long prefillImages(int[] image, int width, int height, int channels) { - int nativeResult = nativeAppendImagesInput(mNativeHandle, image, width, height, channels); + int nativeResult = appendImagesInput(image, width, height, channels); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private static native int nativeAppendImagesInput( - long nativeHandle, int[] image, int width, int height, int channels); + private native int appendImagesInput(int[] image, int width, int height, int channels); /** * Prefill a multimodal Module with the given images input. @@ -242,16 +228,15 @@ private static native int nativeAppendImagesInput( */ @Experimental public long prefillImages(float[] image, int width, int height, int channels) { - int nativeResult = - nativeAppendNormalizedImagesInput(mNativeHandle, image, width, height, channels); + int nativeResult = appendNormalizedImagesInput(image, width, height, channels); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private static native int nativeAppendNormalizedImagesInput( - long nativeHandle, float[] image, int width, int height, int channels); + private native int appendNormalizedImagesInput( + float[] image, int width, int height, int channels); /** * Prefill a multimodal Module with the given audio input. @@ -266,15 +251,14 @@ private static native int nativeAppendNormalizedImagesInput( */ @Experimental public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = nativeAppendAudioInput(mNativeHandle, audio, batch_size, n_bins, n_frames); + int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private static native int nativeAppendAudioInput( - long nativeHandle, byte[] audio, int batch_size, int n_bins, int n_frames); + private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); /** * Prefill a multimodal Module with the given audio input. @@ -289,16 +273,14 @@ private static native int nativeAppendAudioInput( */ @Experimental public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = - nativeAppendAudioInputFloat(mNativeHandle, audio, batch_size, n_bins, n_frames); + int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private static native int nativeAppendAudioInputFloat( - long nativeHandle, float[] audio, int batch_size, int n_bins, int n_frames); + private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames); /** * Prefill a multimodal Module with the given raw audio input. @@ -313,16 +295,15 @@ private static native int nativeAppendAudioInputFloat( */ @Experimental public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { - int nativeResult = - nativeAppendRawAudioInput(mNativeHandle, audio, batch_size, n_channels, n_samples); + int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private static native int nativeAppendRawAudioInput( - long nativeHandle, byte[] audio, int batch_size, int n_channels, int n_samples); + private native int appendRawAudioInput( + byte[] audio, int batch_size, int n_channels, int n_samples); /** * Prefill a multimodal Module with the given text input. @@ -334,7 +315,7 @@ private static native int nativeAppendRawAudioInput( */ @Experimental public long prefillPrompt(String prompt) { - int nativeResult = nativeAppendTextInput(mNativeHandle, prompt); + int nativeResult = appendTextInput(prompt); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } @@ -342,30 +323,20 @@ public long prefillPrompt(String prompt) { } // returns status - private static native int nativeAppendTextInput(long nativeHandle, String prompt); + private native int appendTextInput(String prompt); /** * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. * *

The startPos will be reset to 0. */ - public void resetContext() { - nativeResetContext(mNativeHandle); - } - - private static native void nativeResetContext(long nativeHandle); + public native void resetContext(); /** Stop current generate() before it finishes. */ - public void stop() { - nativeStop(mNativeHandle); - } - - private static native void nativeStop(long nativeHandle); + @DoNotStrip + public native void stop(); /** Force loading the module. Otherwise the model is loaded during first generate(). */ - public int load() { - return nativeLoad(mNativeHandle); - } - - private static native int nativeLoad(long nativeHandle); + @DoNotStrip + public native int load(); } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 636e4de0b36..93c6e111a4d 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -717,10 +717,10 @@ Java_org_pytorch_executorch_Module_nativeEtdump( } // extern "C" #ifdef EXECUTORCH_BUILD_LLAMA_JNI -extern void register_natives_for_llm(JNIEnv* env); +extern void register_natives_for_llm(); #else // No op if we don't build LLM -void register_natives_for_llm(JNIEnv* /* env */) {} +void register_natives_for_llm() {} #endif #ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING @@ -785,7 +785,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { // Register native methods register_natives_for_module(env); - register_natives_for_llm(env); + register_natives_for_llm(); register_natives_for_runtime(env); register_natives_for_training(env); diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index e26b40f4f4b..888e09e7989 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -6,12 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -#include - #include #include #include -#include #include #include #include @@ -33,6 +30,9 @@ #include #endif +#include +#include + #if defined(EXECUTORCH_BUILD_QNN) #include #endif @@ -45,10 +45,6 @@ namespace llm = ::executorch::extension::llm; using ::executorch::runtime::Error; namespace { - -// Global JavaVM pointer for obtaining JNIEnv in callbacks -JavaVM* g_jvm = nullptr; - bool utf8_check_validity(const char* str, size_t length) { for (size_t i = 0; i < length; ++i) { uint8_t byte = static_cast(str[i]); @@ -83,70 +79,47 @@ bool utf8_check_validity(const char* str, size_t length) { } std::string token_buffer; +} // namespace -// Helper to convert jstring to std::string -std::string jstring_to_string(JNIEnv* env, jstring jstr) { - if (jstr == nullptr) { - return ""; - } - const char* chars = env->GetStringUTFChars(jstr, nullptr); - if (chars == nullptr) { - return ""; - } - std::string result(chars); - env->ReleaseStringUTFChars(jstr, chars); - return result; -} - -// Helper to convert Java List to std::vector -std::vector jlist_to_string_vector(JNIEnv* env, jobject jlist) { - std::vector result; - if (jlist == nullptr) { - return result; - } - - jclass list_class = env->FindClass("java/util/List"); - if (list_class == nullptr) { - env->ExceptionClear(); - return result; - } +namespace executorch_jni { - jmethodID size_method = env->GetMethodID(list_class, "size", "()I"); - jmethodID get_method = - env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;"); +class ExecuTorchLlmCallbackJni + : public facebook::jni::JavaClass { + public: + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/extension/llm/LlmCallback;"; - if (size_method == nullptr || get_method == nullptr) { - env->ExceptionClear(); - env->DeleteLocalRef(list_class); - return result; - } + void onResult(std::string result) const { + static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic(); + static const auto method = + cls->getMethod)>("onResult"); - jint size = env->CallIntMethod(jlist, size_method); - for (jint i = 0; i < size; ++i) { - jobject str_obj = env->CallObjectMethod(jlist, get_method, i); - if (str_obj != nullptr) { - result.push_back(jstring_to_string(env, static_cast(str_obj))); - env->DeleteLocalRef(str_obj); + token_buffer += result; + if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { + ET_LOG( + Info, "Current token buffer is not valid UTF-8. Waiting for more."); + return; } + result = token_buffer; + token_buffer = ""; + facebook::jni::local_ref s = facebook::jni::make_jstring(result); + method(self(), s); } - env->DeleteLocalRef(list_class); - return result; -} - -} // namespace - -namespace executorch_jni { - -// Model type category constants -constexpr int MODEL_TYPE_CATEGORY_LLM = 1; -constexpr int MODEL_TYPE_CATEGORY_MULTIMODAL = 2; -constexpr int MODEL_TYPE_MEDIATEK_LLAMA = 3; -constexpr int MODEL_TYPE_QNN_LLAMA = 4; + void onStats(const llm::Stats& result) const { + static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic(); + static const auto on_stats_method = + cls->getMethod)>("onStats"); + on_stats_method( + self(), + facebook::jni::make_jstring( + executorch::extension::llm::stats_to_json_string(result))); + } +}; -// Native handle class that holds the runner state -class ExecuTorchLlmNative { - public: +class ExecuTorchLlmJni : public facebook::jni::HybridClass { + private: + friend HybridBase; float temperature_ = 0.0f; int model_type_category_; std::unique_ptr runner_; @@ -154,13 +127,37 @@ class ExecuTorchLlmNative { multi_modal_runner_; std::vector prefill_inputs_; - ExecuTorchLlmNative( - JNIEnv* env, + public: + constexpr static auto kJavaDescriptor = + "Lorg/pytorch/executorch/extension/llm/LlmModule;"; + + constexpr static int MODEL_TYPE_CATEGORY_LLM = 1; + constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2; + constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3; + constexpr static int MODEL_TYPE_QNN_LLAMA = 4; + + static facebook::jni::local_ref initHybrid( + facebook::jni::alias_ref, jint model_type_category, - jstring model_path, - jstring tokenizer_path, + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path, jfloat temperature, - jobject data_files) { + facebook::jni::alias_ref::javaobject> + data_files) { + return makeCxxInstance( + model_type_category, + model_path, + tokenizer_path, + temperature, + data_files); + } + + ExecuTorchLlmJni( + jint model_type_category, + facebook::jni::alias_ref model_path, + facebook::jni::alias_ref tokenizer_path, + jfloat temperature, + facebook::jni::alias_ref data_files = nullptr) { temperature_ = temperature; #if defined(ET_USE_THREADPOOL) // Reserve 1 thread for the main thread. @@ -174,30 +171,44 @@ class ExecuTorchLlmNative { #endif model_type_category_ = model_type_category; - std::string model_path_str = jstring_to_string(env, model_path); - std::string tokenizer_path_str = jstring_to_string(env, tokenizer_path); - std::vector data_files_vector = - jlist_to_string_vector(env, data_files); - + std::vector data_files_vector; if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { multi_modal_runner_ = llm::create_multimodal_runner( - model_path_str.c_str(), llm::load_tokenizer(tokenizer_path_str)); + model_path->toStdString().c_str(), + llm::load_tokenizer(tokenizer_path->toStdString())); } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { + if (data_files != nullptr) { + // Convert Java List to C++ std::vector + auto list_class = facebook::jni::findClassStatic("java/util/List"); + auto size_method = list_class->getMethod("size"); + auto get_method = + list_class->getMethod(jint)>( + "get"); + + jint size = size_method(data_files); + for (jint i = 0; i < size; ++i) { + auto str_obj = get_method(data_files, i); + auto jstr = facebook::jni::static_ref_cast(str_obj); + data_files_vector.push_back(jstr->toStdString()); + } + } runner_ = executorch::extension::llm::create_text_llm_runner( - model_path_str, llm::load_tokenizer(tokenizer_path_str), data_files_vector); + model_path->toStdString(), + llm::load_tokenizer(tokenizer_path->toStdString()), + data_files_vector); #if defined(EXECUTORCH_BUILD_QNN) } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { std::unique_ptr module = std::make_unique< executorch::extension::Module>( - model_path_str.c_str(), + model_path->toStdString().c_str(), data_files_vector, executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); std::string decoder_model = "llama3"; // use llama3 for now runner_ = std::make_unique>( // QNN runner std::move(module), decoder_model.c_str(), - model_path_str.c_str(), - tokenizer_path_str.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), "", ""); model_type_category_ = MODEL_TYPE_CATEGORY_LLM; @@ -205,530 +216,249 @@ class ExecuTorchLlmNative { #if defined(EXECUTORCH_BUILD_MEDIATEK) } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { runner_ = std::make_unique( - model_path_str.c_str(), tokenizer_path_str.c_str()); + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str()); // Interpret the model type as LLM model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif } } -}; -// Helper class for callback invocation -class CallbackHelper { - public: - CallbackHelper(JNIEnv* env, jobject callback) - : env_(env), callback_(nullptr), callback_class_(nullptr) { - if (callback != nullptr) { - callback_ = env_->NewGlobalRef(callback); - jclass local_class = env_->GetObjectClass(callback); - callback_class_ = static_cast(env_->NewGlobalRef(local_class)); - env_->DeleteLocalRef(local_class); - on_result_method_ = env_->GetMethodID( - callback_class_, "onResult", "(Ljava/lang/String;)V"); - on_stats_method_ = - env_->GetMethodID(callback_class_, "onStats", "(Ljava/lang/String;)V"); + jint generate( + facebook::jni::alias_ref prompt, + jint seq_len, + facebook::jni::alias_ref callback, + jboolean echo) { + if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { + std::vector inputs = prefill_inputs_; + prefill_inputs_.clear(); + if (!prompt->toStdString().empty()) { + inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()}); + } + executorch::extension::llm::GenerationConfig config{ + .echo = static_cast(echo), + .seq_len = seq_len, + .temperature = temperature_, + }; + multi_modal_runner_->generate( + std::move(inputs), + config, + [callback](const std::string& result) { callback->onResult(result); }, + [callback](const llm::Stats& result) { callback->onStats(result); }); + } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { + executorch::extension::llm::GenerationConfig config{ + .echo = static_cast(echo), + .seq_len = seq_len, + .temperature = temperature_, + }; + runner_->generate( + prompt->toStdString(), + config, + [callback](std::string result) { callback->onResult(result); }, + [callback](const llm::Stats& result) { callback->onStats(result); }); } + return 0; } - ~CallbackHelper() { - if (g_jvm == nullptr) { - return; - } - // Get the current JNIEnv (might be different thread) - JNIEnv* env = nullptr; - int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6); - if (status == JNI_EDETACHED) { - g_jvm->AttachCurrentThread(&env, nullptr); - } - if (env != nullptr) { - if (callback_ != nullptr) { - env->DeleteGlobalRef(callback_); - } - if (callback_class_ != nullptr) { - env->DeleteGlobalRef(callback_class_); - } - } + // Returns status_code + // Contract is valid within an AAR (JNI + corresponding Java code) + jint append_text_input(facebook::jni::alias_ref prompt) { + prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()}); + return 0; } - void onResult(const std::string& result) { - JNIEnv* env = getEnv(); - if (env == nullptr || callback_ == nullptr || on_result_method_ == nullptr) { - return; + // Returns status_code + jint append_images_input( + facebook::jni::alias_ref image, + jint width, + jint height, + jint channels) { + std::vector images; + if (image == nullptr) { + return static_cast(Error::EndOfMethod); } - - std::string current_result = result; - token_buffer += current_result; - if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) { - ET_LOG( - Info, "Current token buffer is not valid UTF-8. Waiting for more."); - return; + auto image_size = image->size(); + if (image_size != 0) { + std::vector image_data_jint(image_size); + std::vector image_data(image_size); + image->getRegion(0, image_size, image_data_jint.data()); + for (int i = 0; i < image_size; i++) { + image_data[i] = image_data_jint[i]; + } + llm::Image image_runner{std::move(image_data), width, height, channels}; + prefill_inputs_.emplace_back( + llm::MultimodalInput{std::move(image_runner)}); } - current_result = token_buffer; - token_buffer = ""; - jstring jstr = env->NewStringUTF(current_result.c_str()); - if (jstr != nullptr) { - env->CallVoidMethod(callback_, on_result_method_, jstr); - env->DeleteLocalRef(jstr); - } + return 0; } - void onStats(const llm::Stats& stats) { - JNIEnv* env = getEnv(); - if (env == nullptr || callback_ == nullptr || on_stats_method_ == nullptr) { - return; + // Returns status_code + jint append_normalized_images_input( + facebook::jni::alias_ref image, + jint width, + jint height, + jint channels) { + std::vector images; + if (image == nullptr) { + return static_cast(Error::EndOfMethod); } - - std::string stats_json = - executorch::extension::llm::stats_to_json_string(stats); - jstring jstr = env->NewStringUTF(stats_json.c_str()); - if (jstr != nullptr) { - env->CallVoidMethod(callback_, on_stats_method_, jstr); - env->DeleteLocalRef(jstr); + auto image_size = image->size(); + if (image_size != 0) { + std::vector image_data_jfloat(image_size); + std::vector image_data(image_size); + image->getRegion(0, image_size, image_data_jfloat.data()); + for (int i = 0; i < image_size; i++) { + image_data[i] = image_data_jfloat[i]; + } + llm::Image image_runner{std::move(image_data), width, height, channels}; + prefill_inputs_.emplace_back( + llm::MultimodalInput{std::move(image_runner)}); } + + return 0; } - private: - JNIEnv* getEnv() { - if (g_jvm == nullptr) { - return nullptr; + // Returns status_code + jint append_audio_input( + facebook::jni::alias_ref data, + jint batch_size, + jint n_bins, + jint n_frames) { + if (data == nullptr) { + return static_cast(Error::EndOfMethod); } - JNIEnv* env = nullptr; - int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6); - if (status == JNI_EDETACHED) { - g_jvm->AttachCurrentThread(&env, nullptr); + auto data_size = data->size(); + if (data_size != 0) { + std::vector data_jbyte(data_size); + std::vector data_u8(data_size); + data->getRegion(0, data_size, data_jbyte.data()); + for (int i = 0; i < data_size; i++) { + data_u8[i] = data_jbyte[i]; + } + llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames}; + prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); } - return env; - } - - JNIEnv* env_; - jobject callback_; - jclass callback_class_ = nullptr; - jmethodID on_result_method_ = nullptr; - jmethodID on_stats_method_ = nullptr; -}; - -} // namespace executorch_jni - -extern "C" { - -JNIEXPORT jlong JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeCreate( - JNIEnv* env, - jobject /* this */, - jint model_type_category, - jstring model_path, - jstring tokenizer_path, - jfloat temperature, - jobject data_files) { - auto* native = new executorch_jni::ExecuTorchLlmNative( - env, model_type_category, model_path, tokenizer_path, temperature, data_files); - return reinterpret_cast(native); -} - -JNIEXPORT void JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy( - JNIEnv* /* env */, - jobject /* this */, - jlong native_handle) { - if (native_handle != 0) { - auto* native = - reinterpret_cast(native_handle); - delete native; + return 0; } -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeGenerate( - JNIEnv* env, - jobject /* this */, - jlong native_handle, - jstring prompt, - jint seq_len, - jobject callback, - jboolean echo) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; - } - - std::string prompt_str = jstring_to_string(env, prompt); - - // Create a shared callback helper for use in lambdas - auto callback_helper = - std::make_shared(env, callback); - if (native->model_type_category_ == - executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) { - std::vector inputs = native->prefill_inputs_; - native->prefill_inputs_.clear(); - if (!prompt_str.empty()) { - inputs.emplace_back(llm::MultimodalInput{prompt_str}); + // Returns status_code + jint append_audio_input_float( + facebook::jni::alias_ref data, + jint batch_size, + jint n_bins, + jint n_frames) { + if (data == nullptr) { + return static_cast(Error::EndOfMethod); } - executorch::extension::llm::GenerationConfig config{ - .echo = static_cast(echo), - .seq_len = seq_len, - .temperature = native->temperature_, - }; - native->multi_modal_runner_->generate( - std::move(inputs), - config, - [callback_helper](const std::string& result) { - callback_helper->onResult(result); - }, - [callback_helper](const llm::Stats& result) { - callback_helper->onStats(result); - }); - } else if ( - native->model_type_category_ == - executorch_jni::MODEL_TYPE_CATEGORY_LLM) { - executorch::extension::llm::GenerationConfig config{ - .echo = static_cast(echo), - .seq_len = seq_len, - .temperature = native->temperature_, - }; - native->runner_->generate( - prompt_str, - config, - [callback_helper](std::string result) { - callback_helper->onResult(result); - }, - [callback_helper](const llm::Stats& result) { - callback_helper->onStats(result); - }); - } - return 0; -} - -JNIEXPORT void JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeStop( - JNIEnv* /* env */, - jobject /* this */, - jlong native_handle) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return; - } - - if (native->model_type_category_ == - executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) { - native->multi_modal_runner_->stop(); - } else if ( - native->model_type_category_ == - executorch_jni::MODEL_TYPE_CATEGORY_LLM) { - native->runner_->stop(); - } -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeLoad( - JNIEnv* env, - jobject /* this */, - jlong native_handle) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; + auto data_size = data->size(); + if (data_size != 0) { + std::vector data_jfloat(data_size); + std::vector data_f(data_size); + data->getRegion(0, data_size, data_jfloat.data()); + for (int i = 0; i < data_size; i++) { + data_f[i] = data_jfloat[i]; + } + llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames}; + prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); + } + return 0; } - int result = -1; - std::stringstream ss; - - if (native->model_type_category_ == - executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) { - result = static_cast(native->multi_modal_runner_->load()); - if (result != 0) { - ss << "Failed to load multimodal runner: [" << result << "]"; + // Returns status_code + jint append_raw_audio_input( + facebook::jni::alias_ref data, + jint batch_size, + jint n_channels, + jint n_samples) { + if (data == nullptr) { + return static_cast(Error::EndOfMethod); } - } else if ( - native->model_type_category_ == - executorch_jni::MODEL_TYPE_CATEGORY_LLM) { - result = static_cast(native->runner_->load()); - if (result != 0) { - ss << "Failed to load llm runner: [" << result << "]"; + auto data_size = data->size(); + if (data_size != 0) { + std::vector data_jbyte(data_size); + std::vector data_u8(data_size); + data->getRegion(0, data_size, data_jbyte.data()); + for (int i = 0; i < data_size; i++) { + data_u8[i] = data_jbyte[i]; + } + llm::RawAudio audio{ + std::move(data_u8), batch_size, n_channels, n_samples}; + prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); } - } else { - ss << "Invalid model type category: " << native->model_type_category_ - << ". Valid values are: " - << executorch_jni::MODEL_TYPE_CATEGORY_LLM << " or " - << executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL; - } - if (result != 0) { - executorch::jni_helper::throwExecutorchException( - env, static_cast(Error::InvalidArgument), ss.str().c_str()); - } - return result; // 0 on success to keep backward compatibility -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendTextInput( - JNIEnv* env, - jobject /* this */, - jlong native_handle, - jstring prompt) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; - } - - std::string prompt_str = jstring_to_string(env, prompt); - native->prefill_inputs_.emplace_back(llm::MultimodalInput{prompt_str}); - return 0; -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendImagesInput( - JNIEnv* env, - jobject /* this */, - jlong native_handle, - jintArray image, - jint width, - jint height, - jint channels) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; + return 0; } - if (image == nullptr) { - return static_cast(Error::EndOfMethod); - } - - jsize image_size = env->GetArrayLength(image); - if (image_size != 0) { - std::vector image_data_jint(image_size); - std::vector image_data(image_size); - env->GetIntArrayRegion(image, 0, image_size, image_data_jint.data()); - for (int i = 0; i < image_size; i++) { - image_data[i] = static_cast(image_data_jint[i]); + void stop() { + if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { + multi_modal_runner_->stop(); + } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { + runner_->stop(); } - llm::Image image_runner{std::move(image_data), width, height, channels}; - native->prefill_inputs_.emplace_back( - llm::MultimodalInput{std::move(image_runner)}); } - return 0; -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendNormalizedImagesInput( - JNIEnv* env, - jobject /* this */, - jlong native_handle, - jfloatArray image, - jint width, - jint height, - jint channels) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; - } - - if (image == nullptr) { - return static_cast(Error::EndOfMethod); - } - - jsize image_size = env->GetArrayLength(image); - if (image_size != 0) { - std::vector image_data_jfloat(image_size); - std::vector image_data(image_size); - env->GetFloatArrayRegion(image, 0, image_size, image_data_jfloat.data()); - for (int i = 0; i < image_size; i++) { - image_data[i] = image_data_jfloat[i]; + void reset_context() { + if (runner_ != nullptr) { + runner_->reset(); } - llm::Image image_runner{std::move(image_data), width, height, channels}; - native->prefill_inputs_.emplace_back( - llm::MultimodalInput{std::move(image_runner)}); - } - - return 0; -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInput( - JNIEnv* env, - jobject /* this */, - jlong native_handle, - jbyteArray data, - jint batch_size, - jint n_bins, - jint n_frames) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; - } - - if (data == nullptr) { - return static_cast(Error::EndOfMethod); - } - - jsize data_size = env->GetArrayLength(data); - if (data_size != 0) { - std::vector data_jbyte(data_size); - std::vector data_u8(data_size); - env->GetByteArrayRegion(data, 0, data_size, data_jbyte.data()); - for (int i = 0; i < data_size; i++) { - data_u8[i] = static_cast(data_jbyte[i]); + if (multi_modal_runner_ != nullptr) { + multi_modal_runner_->reset(); } - llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames}; - native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); - } - return 0; -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInputFloat( - JNIEnv* env, - jobject /* this */, - jlong native_handle, - jfloatArray data, - jint batch_size, - jint n_bins, - jint n_frames) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; } - if (data == nullptr) { - return static_cast(Error::EndOfMethod); - } + jint load() { + int result = -1; + std::stringstream ss; - jsize data_size = env->GetArrayLength(data); - if (data_size != 0) { - std::vector data_jfloat(data_size); - std::vector data_f(data_size); - env->GetFloatArrayRegion(data, 0, data_size, data_jfloat.data()); - for (int i = 0; i < data_size; i++) { - data_f[i] = data_jfloat[i]; + if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { + result = static_cast(multi_modal_runner_->load()); + if (result != 0) { + ss << "Failed to load multimodal runner: [" << result << "]"; + } + } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { + result = static_cast(runner_->load()); + if (result != 0) { + ss << "Failed to load llm runner: [" << result << "]"; + } + } else { + ss << "Invalid model type category: " << model_type_category_ + << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " + << MODEL_TYPE_CATEGORY_MULTIMODAL; } - llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames}; - native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); - } - return 0; -} - -JNIEXPORT jint JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendRawAudioInput( - JNIEnv* env, - jobject /* this */, - jlong native_handle, - jbyteArray data, - jint batch_size, - jint n_channels, - jint n_samples) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return -1; - } - - if (data == nullptr) { - return static_cast(Error::EndOfMethod); - } - - jsize data_size = env->GetArrayLength(data); - if (data_size != 0) { - std::vector data_jbyte(data_size); - std::vector data_u8(data_size); - env->GetByteArrayRegion(data, 0, data_size, data_jbyte.data()); - for (int i = 0; i < data_size; i++) { - data_u8[i] = static_cast(data_jbyte[i]); + if (result != 0) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); } - llm::RawAudio audio{std::move(data_u8), batch_size, n_channels, n_samples}; - native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)}); - } - return 0; -} - -JNIEXPORT void JNICALL -Java_org_pytorch_executorch_extension_llm_LlmModule_nativeResetContext( - JNIEnv* /* env */, - jobject /* this */, - jlong native_handle) { - auto* native = - reinterpret_cast(native_handle); - if (native == nullptr) { - return; - } - - if (native->runner_ != nullptr) { - native->runner_->reset(); - } - if (native->multi_modal_runner_ != nullptr) { - native->multi_modal_runner_->reset(); - } -} - -} // extern "C" - -void register_natives_for_llm(JNIEnv* env) { - // Store the JavaVM for later use in callbacks - env->GetJavaVM(&g_jvm); - - jclass llm_module_class = - env->FindClass("org/pytorch/executorch/extension/llm/LlmModule"); - if (llm_module_class == nullptr) { - ET_LOG(Error, "Failed to find LlmModule class"); - env->ExceptionClear(); - return; + return result; // 0 on success to keep backward compatibility + } + + static void registerNatives() { + registerHybrid({ + makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid), + makeNativeMethod("generate", ExecuTorchLlmJni::generate), + makeNativeMethod("stop", ExecuTorchLlmJni::stop), + makeNativeMethod("load", ExecuTorchLlmJni::load), + makeNativeMethod( + "appendImagesInput", ExecuTorchLlmJni::append_images_input), + makeNativeMethod( + "appendNormalizedImagesInput", + ExecuTorchLlmJni::append_normalized_images_input), + makeNativeMethod( + "appendAudioInput", ExecuTorchLlmJni::append_audio_input), + makeNativeMethod( + "appendAudioInputFloat", + ExecuTorchLlmJni::append_audio_input_float), + makeNativeMethod( + "appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input), + makeNativeMethod( + "appendTextInput", ExecuTorchLlmJni::append_text_input), + makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), + }); } +}; - // clang-format off - static const JNINativeMethod methods[] = { - {"nativeCreate", - "(ILjava/lang/String;Ljava/lang/String;FLjava/util/List;)J", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeCreate)}, - {"nativeDestroy", "(J)V", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy)}, - {"nativeGenerate", - "(JLjava/lang/String;ILorg/pytorch/executorch/extension/llm/LlmCallback;Z)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeGenerate)}, - {"nativeStop", "(J)V", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeStop)}, - {"nativeLoad", "(J)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeLoad)}, - {"nativeAppendTextInput", "(JLjava/lang/String;)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendTextInput)}, - {"nativeAppendImagesInput", "(J[IIII)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendImagesInput)}, - {"nativeAppendNormalizedImagesInput", "(J[FIII)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendNormalizedImagesInput)}, - {"nativeAppendAudioInput", "(J[BIII)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInput)}, - {"nativeAppendAudioInputFloat", "(J[FIII)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInputFloat)}, - {"nativeAppendRawAudioInput", "(J[BIII)I", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendRawAudioInput)}, - {"nativeResetContext", "(J)V", - reinterpret_cast( - Java_org_pytorch_executorch_extension_llm_LlmModule_nativeResetContext)}, - }; - // clang-format on - - int num_methods = sizeof(methods) / sizeof(methods[0]); - int result = env->RegisterNatives(llm_module_class, methods, num_methods); - if (result != JNI_OK) { - ET_LOG(Error, "Failed to register native methods for LlmModule"); - } +} // namespace executorch_jni - env->DeleteLocalRef(llm_module_class); +void register_natives_for_llm() { + executorch_jni::ExecuTorchLlmJni::registerNatives(); }