Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance ONNX importer with better error handling and output support. #3971

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 57 additions & 37 deletions projects/onnx_c_importer/import-onnx-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
#include "llvm/Support/raw_ostream.h"

#include "OnnxImporter.h"

#include "onnx/onnx_pb.h"

#include <fstream>
#include <iostream>
#include <memory>

using namespace llvm;
using namespace torch_mlir_onnx;

// Encapsulates MLIR context and module management
struct MlirState {
MlirState() {
context = mlirContextCreateWithThreading(false);
Expand All @@ -42,62 +43,81 @@ struct MlirState {
};

int main(int argc, char **argv) {
// Define command-line options
static cl::opt<std::string> inputFilename(
cl::Positional, cl::desc("<input file>"), cl::init("-"));
static cl::opt<std::string> outputFilename(
"o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-"));

static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"));

// Initialize LLVM and parse command-line options
InitLLVM y(argc, argv);
cl::ParseCommandLineOptions(argc, argv, "torch-mlir-onnx-import-c");

// Open the input as an istream because that is what protobuf likes.
std::unique_ptr<std::ifstream> alloced_input_stream;
std::istream *input_stream = nullptr;
// Open the input file stream
std::unique_ptr<std::ifstream> allocedInputStream;
std::istream *inputStream = nullptr;
if (inputFilename == "-") {
errs() << "(parsing from stdin)\n";
input_stream = &std::cin;
errs() << "(Parsing from stdin)\n";
inputStream = &std::cin;
} else {
alloced_input_stream = std::make_unique<std::ifstream>(
allocedInputStream = std::make_unique<std::ifstream>(
inputFilename, std::ios::in | std::ios::binary);
if (!*alloced_input_stream) {
errs() << "error: could not open input file " << inputFilename << "\n";
return 1;
if (!allocedInputStream->is_open()) {
errs() << "Error: Could not open input file: " << inputFilename << "\n";
return EXIT_FAILURE;
}
input_stream = alloced_input_stream.get();
inputStream = allocedInputStream.get();
}

// Parse the model proto.
ModelInfo model_info;
if (!model_info.model_proto().ParseFromIstream(input_stream)) {
errs() << "Failed to parse ONNX ModelProto from " << inputFilename << "\n";
return 2;
// Parse the ONNX model proto
ModelInfo modelInfo;
if (!modelInfo.model_proto().ParseFromIstream(inputStream)) {
errs() << "Error: Failed to parse ONNX ModelProto from " << inputFilename << "\n";
return EXIT_FAILURE;
}

if (failed(model_info.Initialize())) {
errs() << "error: Import failure: " << model_info.error_message() << "\n";
model_info.DebugDumpProto();
return 3;
// Initialize model information
if (failed(modelInfo.Initialize())) {
errs() << "Error: Import failure: " << modelInfo.error_message() << "\n";
modelInfo.DebugDumpProto();
return EXIT_FAILURE;
}
model_info.DebugDumpProto();
modelInfo.DebugDumpProto();

// Create MLIR state and context cache
MlirState ownedState;
ContextCache contextCache(modelInfo, ownedState.context);

// Import.
MlirState owned_state;
ContextCache cc(model_info, owned_state.context);
NodeImporter importer(model_info.main_graph(), cc,
mlirModuleGetOperation(owned_state.module));
// Import the ONNX graph into MLIR
NodeImporter importer(
modelInfo.main_graph(), contextCache, mlirModuleGetOperation(ownedState.module));
if (failed(importer.DefineFunction())) {
errs() << "error: Could not define MLIR function for graph: "
<< model_info.error_message() << "\n";
return 4;
errs() << "Error: Could not define MLIR function for graph: "
<< modelInfo.error_message() << "\n";
return EXIT_FAILURE;
}
if (failed(importer.ImportAll())) {
errs() << "error: Could not import one or more graph nodes: "
<< model_info.error_message() << "\n";
return 5;
errs() << "Error: Could not import one or more graph nodes: "
<< modelInfo.error_message() << "\n";
return EXIT_FAILURE;
}

// Dump the imported MLIR module
importer.DebugDumpModule();

return 0;
// Optional: Save the output MLIR module to a file
if (outputFilename != "-") {
std::ofstream outFile(outputFilename, std::ios::out);
if (!outFile.is_open()) {
errs() << "Error: Could not open output file: " << outputFilename << "\n";
return EXIT_FAILURE;
}
mlirOperationPrint(mlirModuleGetOperation(ownedState.module), outFile);
outs() << "Successfully saved MLIR module to " << outputFilename << "\n";
} else {
outs() << "MLIR module processing complete. Output not saved to a file.\n";
}

return EXIT_SUCCESS;
}