|
23 | 23 | #include <executorch/extension/data_loader/buffer_data_loader.h>
|
24 | 24 | #include <executorch/extension/data_loader/mmap_data_loader.h>
|
25 | 25 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
|
| 26 | +#include <executorch/extension/module/module.h> |
26 | 27 | #include <executorch/extension/threadpool/threadpool.h>
|
27 | 28 | #include <executorch/runtime/backend/interface.h>
|
28 | 29 | #include <executorch/runtime/core/data_loader.h>
|
@@ -442,11 +443,12 @@ inline std::unique_ptr<Module> load_module_from_file(
|
442 | 443 |
|
443 | 444 | static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
|
444 | 445 |
|
445 |
| -struct PyBundledModule final { |
| 446 | +struct PyBundledModule : public BundledModule { |
446 | 447 | explicit PyBundledModule(
|
447 | 448 | const py::bytes& buffer,
|
448 | 449 | uint32_t bundled_input_pool_size)
|
449 |
| - : bundled_program_ptr_(buffer), |
| 450 | + : BundledModule(buffer.cast<std::string_view>().data()), |
| 451 | + bundled_program_ptr_(buffer), |
450 | 452 | program_ptr_(static_cast<const void*>(
|
451 | 453 | bundled_program_flatbuffer::GetBundledProgram(
|
452 | 454 | get_bundled_program_ptr())
|
@@ -840,22 +842,26 @@ struct PyModule final {
|
840 | 842 | size_t testset_idx,
|
841 | 843 | double rtol = 1e-5,
|
842 | 844 | double atol = 1e-8) {
|
843 |
| - const void* bundled_program_ptr = m.get_bundled_program_ptr(); |
844 |
| - auto& method = module_->get_method(method_name); |
845 |
| - Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input( |
846 |
| - method, bundled_program_ptr, testset_idx); |
| 845 | + auto status = m.load_bundled_input(method_name, testset_idx); |
847 | 846 | THROW_IF_ERROR(
|
848 | 847 | status,
|
849 |
| - "load_bundled_input failed with status 0x%" PRIx32, |
| 848 | + "Load input from bundled to method failed with status %" PRIu32, |
850 | 849 | static_cast<uint32_t>(status));
|
851 |
| - py::list outputs = plan_execute(method_name); |
852 |
| - status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs( |
853 |
| - method, bundled_program_ptr, testset_idx, rtol, atol); |
| 850 | + |
| 851 | + auto outputs = m.Module::execute(method_name); |
| 852 | + |
| 853 | + THROW_IF_ERROR( |
| 854 | + outputs.error(), |
| 855 | + "Execution failed with status 0x%" PRIx32, |
| 856 | + static_cast<uint32_t>(outputs.error())); |
| 857 | + |
| 858 | + status = m.verify_method_outputs(method_name, testset_idx, rtol, atol); |
854 | 859 | THROW_IF_ERROR(
|
855 | 860 | status,
|
856 | 861 | "Result verification failed with status %" PRIu32,
|
857 | 862 | static_cast<uint32_t>(status));
|
858 |
| - return outputs; |
| 863 | + |
| 864 | + return get_outputs_as_py_list(outputs.get()); |
859 | 865 | }
|
860 | 866 |
|
861 | 867 | py::list plan_execute(
|
|
0 commit comments