From 63bc5ec109f7d1aec4d519fe4273a5ba6410eaa9 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Tue, 20 Jun 2023 20:24:53 -0700 Subject: [PATCH 01/12] add DisposableElementsAttr offset Signed-off-by: Soren Lassen --- .../ElementsAttr/DisposableElementsAttr.cpp | 5 ++++- .../ElementsAttr/DisposableElementsAttr.hpp | 4 +++- .../DisposableElementsAttributeStorage.hpp | 4 +++- .../ONNX/ElementsAttr/DisposablePool.cpp | 4 ++-- .../ONNX/ElementsAttr/DisposablePool.hpp | 2 +- .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 22 +++++++++++-------- .../ONNX/ElementsAttr/ElementsAttrBuilder.hpp | 2 +- 7 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp index ae5ffaa589..728a54cf48 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp @@ -58,7 +58,7 @@ void widenArray( /*static*/ DisposableElementsAttr DisposableElementsAttr::create(ShapedType type, size_t id, BType bufferBType, ArrayRef strides, - const Buffer &buffer, Transformer transformer) { + const Buffer &buffer, uint64_t offset, Transformer transformer) { BType btype = btypeOfMlirType(type.getElementType()); assert((transformer != nullptr || wideBTypeOfBType(bufferBType) == wideBTypeOfBType(btype)) && @@ -68,6 +68,7 @@ DisposableElementsAttr DisposableElementsAttr::create(ShapedType type, type.getContext(), type, strides, bufferBType, btype, isContiguous, id); DisposableElementsAttributeStorage &s = *a.getImpl(); s.buffer = buffer; + s.offset = offset; s.transformer = std::move(transformer); return a; } @@ -98,6 +99,8 @@ auto DisposableElementsAttr::getBuffer() const -> const Buffer & { return getImpl()->buffer; } +uint64_t DisposableElementsAttr::getOffset() const { return getImpl()->offset; } + auto DisposableElementsAttr::getTransformer() const -> const Transformer & { assert(!isDisposed()); return getImpl()->transformer; diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index 2f8912f577..8308e9ea6a 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -118,7 +118,7 @@ class DisposableElementsAttr // created instance. static DisposableElementsAttr create(ShapedType type, size_t id, BType bufferBType, ArrayRef strides, const Buffer &buffer, - Transformer transformer); + uint64_t offset, Transformer transformer); // Clears the buffer payload shared_ptr which decreases the reference count // and, if it reaches zero, frees or closes the underlying MemoryBuffer's @@ -154,6 +154,8 @@ class DisposableElementsAttr const Buffer &getBuffer() const; + uint64_t getOffset() const; + const Transformer &getTransformer() const; bool isContiguous() const; diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp index 23f4c5f9d5..11fc23805b 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp @@ -37,7 +37,7 @@ struct DisposableElementsAttributeStorage : public AttributeStorage { onnx_mlir::BType bufferBType, onnx_mlir::BType btype, bool isContiguous, size_t id) : type(type), strides(strides), bufferBType(bufferBType), btype(btype), - isContiguous(isContiguous), id(id) {} + isContiguous(isContiguous), id(id), offset(0) {} // Equality and hashKey are engineered to defeat the storage uniquer. // We don't want uniqueing because we can't compare transformers for equality @@ -103,6 +103,8 @@ struct DisposableElementsAttributeStorage : public AttributeStorage { // file closed) when no one points to it anymore. Buffer buffer; + uint64_t offset; + // Reads the buffer elements to WideNums corresponding to type's // element type. Is null if data is not transformed. // In this case the buffer data type and the type's element type must promote diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp index 9342b9b286..60b87c0b1b 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp @@ -24,12 +24,12 @@ DisposablePool::~DisposablePool() {} ElementsAttr DisposablePool::createElementsAttr(ShapedType type, BType bufferBType, ArrayRef strides, - const mlir::DisposableElementsAttr::Buffer &buffer, + const mlir::DisposableElementsAttr::Buffer &buffer, uint64_t offset, DisposableElementsAttr::Transformer transformer) { static std::atomic counter{0}; size_t id = ++counter; auto disposable = DisposableElementsAttr::create( - type, id, bufferBType, strides, buffer, std::move(transformer)); + type, id, bufferBType, strides, buffer, offset, std::move(transformer)); if (insert(disposable)) { return disposable; } else { diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp index 9a538c06ef..08fdff8af8 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp @@ -57,7 +57,7 @@ class DisposablePool : public mlir::DialectInterface::Base { // otherwise returns conversion to DenseElementsAttr. mlir::ElementsAttr createElementsAttr(mlir::ShapedType type, BType bufferBType, llvm::ArrayRef strides, - const mlir::DisposableElementsAttr::Buffer &buffer, + const mlir::DisposableElementsAttr::Buffer &buffer, uint64_t offset, mlir::DisposableElementsAttr::Transformer transformer); // Disposes every DisposableElementsAttr in the pool which is unreachable diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index e0e23f03df..809bb8ea7d 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -58,6 +58,7 @@ struct ElementsAttrBuilder::ElementsProperties { BType bufferBType; SmallVector strides; std::shared_ptr buffer; + uint64_t offset; const Transformer &transformer; }; @@ -79,7 +80,7 @@ DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( return nullptr; ElementsProperties props = getElementsProperties(dense); ElementsAttr created = create(dense.getType(), props.bufferBType, - props.strides, props.buffer, props.transformer); + props.strides, props.buffer, props.offset, props.transformer); // Check for race condition where disposablePool became inactive since we // checked, in which case it returns a DenseElementsAttr which we don't // want. @@ -328,7 +329,7 @@ ElementsAttr ElementsAttrBuilder::castElementType( : composeTransforms(props.transformer, functionTransformer(wideCaster(oldWideType, newWideType))); return create(newType, props.bufferBType, props.strides, props.buffer, - std::move(transformer)); + props.offset, std::move(transformer)); } namespace { @@ -353,7 +354,7 @@ ElementsAttr ElementsAttrBuilder::transpose( ShapedType transposedType = type.clone(transposedShape); auto transposedStrides = transposeDims(props.strides, perm); return create(transposedType, props.bufferBType, transposedStrides, - props.buffer, props.transformer); + props.buffer, props.offset, props.transformer); } ElementsAttr ElementsAttrBuilder::reshape( @@ -369,7 +370,7 @@ ElementsAttr ElementsAttrBuilder::reshape( if (auto reshapedStrides = reshapeStrides(shape, props.strides, reshapedShape)) return create(reshapedType, props.bufferBType, *reshapedStrides, - props.buffer, props.transformer); + props.buffer, props.offset, props.transformer); auto disp = elms.dyn_cast(); assert(disp && "reshapeStrides() always succeeds for non-Disposable " @@ -399,7 +400,7 @@ ElementsAttr ElementsAttrBuilder::expand( ShapedType expandedType = type.clone(expandedShape); auto expandedStrides = expandStrides(props.strides, expandedShape); return create(expandedType, props.bufferBType, expandedStrides, props.buffer, - props.transformer); + props.offset, props.transformer); } namespace { @@ -836,6 +837,7 @@ auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) return {/*.bufferBType=*/disposable.getBufferBType(), /*.strides=*/{strides.begin(), strides.end()}, /*.buffer=*/disposable.getBuffer(), + /*.offset=*/disposable.getOffset(), /*.transformer=*/disposable.getTransformer()}; } else if (auto dense = elements.dyn_cast()) { ShapedType type = dense.getType(); @@ -848,6 +850,7 @@ auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) return {/*.bufferBType=*/btypeOfMlirType(type.getElementType()), /*.strides=*/{strides.begin(), strides.end()}, /*.buffer=*/getMemoryBuffer(dense), + /*.offset=*/0, /*.transformer=*/nullTransformer}; } // TODO: consider supporting more ElementsAttr types @@ -879,6 +882,7 @@ ElementsAttr ElementsAttrBuilder::doTransform( ElementsProperties props = getElementsProperties(elms); return create(transformedType, props.bufferBType, props.strides, props.buffer, + props.offset, composeTransforms(props.transformer, std::move(transformer))); } @@ -890,7 +894,7 @@ ElementsAttr ElementsAttrBuilder::expandAndTransform(ElementsAttr elms, expandStrides(props.strides, expandedTransformedType.getShape()); return create(expandedTransformedType, props.bufferBType, expandedStrides, - props.buffer, + props.buffer, props.offset, composeTransforms(props.transformer, std::move(transformer))); } @@ -906,15 +910,15 @@ ElementsAttr ElementsAttrBuilder::fromRawBytes( ElementsAttr ElementsAttrBuilder::createWithDefaultStrides(ShapedType type, BType bufferBType, std::unique_ptr membuf) { auto strides = getDefaultStrides(type.getShape()); - return create(type, bufferBType, strides, std::move(membuf)); + return create(type, bufferBType, strides, std::move(membuf), 0); } ElementsAttr ElementsAttrBuilder::create(ShapedType type, BType bufferBType, ArrayRef strides, - const std::shared_ptr &buffer, + const std::shared_ptr &buffer, uint64_t offset, Transformer transformer) { return disposablePool.createElementsAttr( - type, bufferBType, strides, buffer, std::move(transformer)); + type, bufferBType, strides, buffer, offset, std::move(transformer)); } } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 1dd7ea3dba..06a49128f5 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -233,7 +233,7 @@ class ElementsAttrBuilder { // Create a DisposableElementsAttr and put it in disposablePool. mlir::ElementsAttr create(mlir::ShapedType type, BType bufferBType, llvm::ArrayRef strides, - const std::shared_ptr &buffer, + const std::shared_ptr &buffer, uint64_t offset, Transformer transformer = nullptr); DisposablePool &disposablePool; From 412046f7c5f1fb52095db2bcd6342618c106c4b1 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Tue, 20 Jun 2023 22:33:28 -0700 Subject: [PATCH 02/12] begin with offset Signed-off-by: Soren Lassen --- .../ONNX/ElementsAttr/DisposableElementsAttr.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp index 728a54cf48..34b5de9e81 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp @@ -127,7 +127,13 @@ unsigned DisposableElementsAttr::getBufferElementBytewidth() const { } int64_t DisposableElementsAttr::getNumBufferElements() const { - return getBuffer()->getBufferSize() / getBufferElementBytewidth(); + int64_t lastPos = 0; + for (auto [dimSize, stride] : llvm::zip(getShape(), getStrides())) { + if (dimSize == 0) + return 0; + lastPos += (dimSize - 1) * stride; + } + return lastPos + 1; } ArrayBuffer DisposableElementsAttr::getWideNums() const { @@ -253,7 +259,8 @@ void DisposableElementsAttr::readBytesAsWideNums( } ArrayRef DisposableElementsAttr::getBufferBytes() const { - return asArrayRef(getBuffer()->getBuffer()); + size_t numBytes = getNumBufferElements() * getBufferElementBytewidth(); + return asArrayRef(getBuffer()->getBuffer().substr(getOffset(), numBytes)); } ArrayBuffer DisposableElementsAttr::getBufferAsWideNums() const { From b75378180aaf67e735882ec031fd3157d9268d2a Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 06:59:01 -0700 Subject: [PATCH 03/12] check that mmap buffer identifier is file path Signed-off-by: Soren Lassen --- src/Builder/FrontendDialectHelper.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Builder/FrontendDialectHelper.cpp b/src/Builder/FrontendDialectHelper.cpp index 4f4d56a320..792911c533 100644 --- a/src/Builder/FrontendDialectHelper.cpp +++ b/src/Builder/FrontendDialectHelper.cpp @@ -61,16 +61,18 @@ std::unique_ptr readExternalData_LE( } } assert(!location.empty() && "missing external data location"); - SmallVector path(externalDataDir.begin(), externalDataDir.end()); - llvm::sys::path::append(path, location); + SmallVector pathBuf(externalDataDir.begin(), externalDataDir.end()); + llvm::sys::path::append(pathBuf, location); + StringRef path(pathBuf.data(), pathBuf.size()); auto bufferOrError = llvm::MemoryBuffer::getFileSlice( path, length, offset, /*IsVolatile=*/false); if (std::error_code ec = bufferOrError.getError()) { - std::string pathStr(path.data(), path.size()); - llvm::errs() << "Error " << ec.message() << " reading from file " << pathStr + llvm::errs() << "Error " << ec.message() << " reading from file " << path << ", offset=" << offset << ", length=" << length << "\n"; llvm_unreachable("llvm::MemoryBuffer::getFileSlice failed"); } + assert(bufferOrError.get()->getBufferIdentifier() == path && + "buffer identifier is file path"); return std::move(bufferOrError.get()); } From 7767a29ac8b97db9bce447a2e135040b66a00975 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 09:57:31 -0700 Subject: [PATCH 04/12] use the same mmap for all tensors from the same external data file Signed-off-by: Soren Lassen --- src/Builder/FrontendDialectHelper.cpp | 61 ++---------- src/Builder/FrontendDialectHelper.hpp | 16 +++- src/Builder/FrontendDialectTransformer.cpp | 93 +++++++++++++++++-- .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 14 ++- .../ONNX/ElementsAttr/ElementsAttrBuilder.hpp | 8 +- src/Dialect/ONNX/ONNXAttributes.cpp | 3 +- 6 files changed, 126 insertions(+), 69 deletions(-) diff --git a/src/Builder/FrontendDialectHelper.cpp b/src/Builder/FrontendDialectHelper.cpp index 792911c533..34d6fdc3bc 100644 --- a/src/Builder/FrontendDialectHelper.cpp +++ b/src/Builder/FrontendDialectHelper.cpp @@ -17,8 +17,6 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Endian.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Path.h" #include "llvm/Support/SwapByteOrder.h" #include "src/Dialect/ONNX/ElementsAttr/Arrays.hpp" @@ -33,49 +31,6 @@ namespace onnx_mlir { namespace { -// Parses unsigned number. -size_t parseOffsetOrLength(const std::string &value) { - char *end = nullptr; - size_t offsetOrLength = strtoull(value.c_str(), &end, 0); - assert(end != value.c_str() && "failed to parse offset or length"); - return offsetOrLength; -} - -// Reads external data from file location specified in tensor proto. -// The data is little endian encoded. -// See https://github.com/onnx/onnx/blob/main/docs/ExternalData.md -std::unique_ptr readExternalData_LE( - const std::string &externalDataDir, const onnx::TensorProto &tp) { - std::string location; - uint64_t offset = 0; - uint64_t length = -1; // MemoryBuffer uses -1 to mean infinity - for (const onnx::StringStringEntryProto &entry : tp.external_data()) { - assert(entry.has_key() && "external_data entry must have key"); - assert(entry.has_value() && "external_data entry must have value"); - if (entry.key() == "location") { - location = entry.value(); - } else if (entry.key() == "offset") { - offset = parseOffsetOrLength(entry.value()); - } else if (entry.key() == "length") { - length = parseOffsetOrLength(entry.value()); - } - } - assert(!location.empty() && "missing external data location"); - SmallVector pathBuf(externalDataDir.begin(), externalDataDir.end()); - llvm::sys::path::append(pathBuf, location); - StringRef path(pathBuf.data(), pathBuf.size()); - auto bufferOrError = llvm::MemoryBuffer::getFileSlice( - path, length, offset, /*IsVolatile=*/false); - if (std::error_code ec = bufferOrError.getError()) { - llvm::errs() << "Error " << ec.message() << " reading from file " << path - << ", offset=" << offset << ", length=" << length << "\n"; - llvm_unreachable("llvm::MemoryBuffer::getFileSlice failed"); - } - assert(bufferOrError.get()->getBufferIdentifier() == path && - "buffer identifier is file path"); - return std::move(bufferOrError.get()); -} - template struct TransformValueToONNXData { static const google::protobuf::RepeatedField &data( @@ -158,15 +113,15 @@ T swappedBytes(T x) { template ElementsAttr createElementsAttrFromMemoryBuffer_LE( - RankedTensorType tensorType, std::unique_ptr membuf) { + RankedTensorType tensorType, const ExternalDataFileSlice &fileSlice) { MLIRContext *ctx = tensorType.getContext(); assert(tensorType.getElementType() == toMlirType(ctx)); if constexpr (shouldSwapLEBytes) { - ArrayRef array = asArrayRef(membuf->getBuffer()); + ArrayRef array = asArrayRef(fileSlice.getBufferSlice()); return createElmAttrFromArray(tensorType, array, swappedBytes); } else { return OnnxElementsAttrBuilder(ctx).fromMemoryBuffer( - tensorType, std::move(membuf)); + tensorType, fileSlice.file, fileSlice.offset, fileSlice.length); } } @@ -205,11 +160,12 @@ ElementsAttr createElmAttrFromProtoData(RankedTensorType tensorType, // Returns ElementsAttr with tp's data. template ElementsAttr createElmAttr(RankedTensorType tensorType, - const onnx::TensorProto &tp, const std::string &externalDataDir) { + const onnx::TensorProto &tp, + const ExternalDataFileSlice *externalDataFileSlice) { if (tp.has_data_location() && tp.data_location() == onnx::TensorProto::EXTERNAL) { return createElementsAttrFromMemoryBuffer_LE( - tensorType, readExternalData_LE(externalDataDir, tp)); + tensorType, *externalDataFileSlice); } if (tp.has_raw_data()) { return createElmAttrFromRawBytes_LE( @@ -238,7 +194,8 @@ ElementsAttr createStringElmAttr( } // namespace ElementsAttr onnxTensorProtoToElmAttr(MLIRContext *ctx, - const std::string &externalDataDir, const onnx::TensorProto &tp) { + const onnx::TensorProto &tp, + const ExternalDataFileSlice *externalDataFileSlice) { // Tensor dimensions. ArrayRef tensorDims(tp.dims().data(), tp.dims().size()); if (tp.data_type() == onnx::TensorProto::STRING) { @@ -251,7 +208,7 @@ ElementsAttr onnxTensorProtoToElmAttr(MLIRContext *ctx, auto tensorType = RankedTensorType::get(tensorDims, elmType); return dispatchByBType(btype, [&](auto btype) { using cpptype = CppType; - return createElmAttr(tensorType, tp, externalDataDir); + return createElmAttr(tensorType, tp, externalDataFileSlice); }); } diff --git a/src/Builder/FrontendDialectHelper.hpp b/src/Builder/FrontendDialectHelper.hpp index eb3de87a79..c2ce44483d 100644 --- a/src/Builder/FrontendDialectHelper.hpp +++ b/src/Builder/FrontendDialectHelper.hpp @@ -15,14 +15,26 @@ #pragma once #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/MemoryBuffer.h" #include "onnx/onnx_pb.h" -#include +#include namespace onnx_mlir { +struct ExternalDataFileSlice { + std::shared_ptr file; + uint64_t offset; + uint64_t length; + llvm::StringRef getBufferSlice() const { + return file->getBuffer().substr(offset, length); + } +}; + mlir::ElementsAttr onnxTensorProtoToElmAttr(mlir::MLIRContext *ctx, - const std::string &externalDataDir, const onnx::TensorProto &initializer); + const onnx::TensorProto &initializer, + const ExternalDataFileSlice *externalDataFileSlice = nullptr); } // namespace onnx_mlir diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 0720e7c98c..11ae8a60a2 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/LineIterator.h" #include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" #include "include/onnx-mlir/Compiler/OMCompilerTypes.h" #include "src/Builder/FrontendDialectTransformer.hpp" @@ -142,6 +143,14 @@ void replaceAttrRefs(onnx::GraphProto &graph, const AttrMap &attr_map) { // End of copied code from third_party/onnx. // -------------------------------------------------------------------------- // +// Parses unsigned number. +size_t parseOffsetOrLength(const std::string &value) { + char *end = nullptr; + size_t offsetOrLength = strtoull(value.c_str(), &end, 0); + assert(end != value.c_str() && "failed to parse offset or length"); + return offsetOrLength; +} + } // namespace namespace detail { @@ -211,6 +220,66 @@ class FrontendGenImpl { ModelLocalFunctionsMap in_model_functions_; + using ExternalDataFiles = + std::unordered_map>; + + ExternalDataFiles externalDataFiles_; + + const std::shared_ptr &mapExternalDataFile( + const std::string &location) { + auto [iter, inserted] = externalDataFiles_.try_emplace(location, nullptr); + if (inserted) { + StringRef dir = options_.externalDataDir; + SmallVector pathVector(dir.begin(), dir.end()); + llvm::sys::path::append(pathVector, location); + StringRef path(pathVector.data(), pathVector.size()); + // Memory maps file (in most cases) or reads it into memory. + auto bufferOrError = llvm::MemoryBuffer::getFile( + path, /*IsText=*/false, /*RequiresNullTerminator=*/false); + if (std::error_code ec = bufferOrError.getError()) { + llvm::errs() << "Error " << ec.message() << " reading from file " + << path << "\n"; + llvm_unreachable("llvm::MemoryBuffer::getFile failed"); + } + std::unique_ptr buffer = + std::move(bufferOrError.get()); + assert(buffer->getBufferIdentifier() == path && + "buffer identifier is file path"); + iter->second = std::move(buffer); + } + return iter->second; + } + + ExternalDataFileSlice readExternalData(const onnx::TensorProto &tp) { + assert(tp.has_data_location() && + tp.data_location() == onnx::TensorProto::EXTERNAL && + "tensor proto data must be external"); + // MemoryBuffer uses -1 to mean infinity + constexpr uint64_t infiniteLength = -1; + std::string location; + uint64_t offset = 0; + uint64_t length = infiniteLength; + for (const onnx::StringStringEntryProto &entry : tp.external_data()) { + assert(entry.has_key() && "external_data entry must have key"); + assert(entry.has_value() && "external_data entry must have value"); + if (entry.key() == "location") { + location = entry.value(); + } else if (entry.key() == "offset") { + offset = parseOffsetOrLength(entry.value()); + } else if (entry.key() == "length") { + length = parseOffsetOrLength(entry.value()); + } + } + assert(!location.empty() && "missing external data location"); + const std::shared_ptr &file = + mapExternalDataFile(location); + uint64_t fileLength = file->getBufferSize(); + assert(offset <= fileLength && "offset out of range"); + if (length != infiniteLength) + assert(offset + length <= fileLength && "length out of range"); + return {file, offset, length}; + } + Location UnknownLoc() const { return UnknownLoc::get(&context_); } Location ImportLoc(const onnx::NodeProto &node) { @@ -264,8 +333,14 @@ class FrontendGenImpl { } Value ImportTensor(const onnx::TensorProto &tensor) { - mlir::ElementsAttr mlirAttr = - onnxTensorProtoToElmAttr(&context_, options_.externalDataDir, tensor); + mlir::ElementsAttr mlirAttr; + if (tensor.has_data_location() && + tensor.data_location() == onnx::TensorProto::EXTERNAL) { + ExternalDataFileSlice fileSlice = readExternalData(tensor); + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor, &fileSlice); + } else { + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor); + } // Use the tensor name as Location. auto loc = NameLoc::get(builder_.getStringAttr("Initializer_" + tensor.name())); @@ -385,10 +460,16 @@ class FrontendGenImpl { mlirAttr = builder_.getI64ArrayAttr( llvm::ArrayRef(attr.ints().data(), attr.ints().size())); break; - case onnx::AttributeProto::TENSOR: - mlirAttr = onnxTensorProtoToElmAttr( - &context_, options_.externalDataDir, attr.t()); - break; + case onnx::AttributeProto::TENSOR: { + const onnx::TensorProto &tensor = attr.t(); + if (tensor.has_data_location() && + tensor.data_location() == onnx::TensorProto::EXTERNAL) { + ExternalDataFileSlice fileSlice = readExternalData(tensor); + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor, &fileSlice); + } else { + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor); + } + } break; case onnx::AttributeProto::STRINGS: { llvm::SmallVector vectorStringRef; for (const auto &item : attr.strings()) { diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 809bb8ea7d..4c8bb53822 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -65,10 +65,13 @@ struct ElementsAttrBuilder::ElementsProperties { ElementsAttrBuilder::ElementsAttrBuilder(DisposablePool &disposablePool) : disposablePool(disposablePool) {} -ElementsAttr ElementsAttrBuilder::fromMemoryBuffer( - ShapedType type, std::unique_ptr membuf) { +ElementsAttr ElementsAttrBuilder::fromMemoryBuffer(ShapedType type, + std::shared_ptr membuf, uint64_t offset, + uint64_t length) { BType btype = btypeOfMlirType(type.getElementType()); - return createWithDefaultStrides(type, btype, std::move(membuf)); + uint64_t numBytes = type.getNumElements() * bytewidthOfBType(btype); + assert(numBytes == length && "length mismatch"); + return createWithDefaultStrides(type, btype, std::move(membuf), offset); } DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( @@ -908,9 +911,10 @@ ElementsAttr ElementsAttrBuilder::fromRawBytes( } ElementsAttr ElementsAttrBuilder::createWithDefaultStrides(ShapedType type, - BType bufferBType, std::unique_ptr membuf) { + BType bufferBType, std::shared_ptr membuf, + uint64_t offset) { auto strides = getDefaultStrides(type.getShape()); - return create(type, bufferBType, strides, std::move(membuf), 0); + return create(type, bufferBType, strides, std::move(membuf), offset); } ElementsAttr ElementsAttrBuilder::create(ShapedType type, BType bufferBType, diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 06a49128f5..d87e5edd72 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -34,8 +34,9 @@ class ElementsAttrBuilder { // The created instance takes ownership of membuf and will release it when the // instance is disposed by garbage collection, unless it has shared membuf // with other DisposableElementsAttr instances that are longer lived. - mlir::ElementsAttr fromMemoryBuffer( - mlir::ShapedType type, std::unique_ptr membuf); + mlir::ElementsAttr fromMemoryBuffer(mlir::ShapedType type, + std::shared_ptr membuf, uint64_t offset, + uint64_t length); // Wraps elements in a DisposableElementsAttr if it isn't already a // DisposableElementsAttr, provided the underlying DisposablePool is active. @@ -228,7 +229,8 @@ class ElementsAttrBuilder { const Filler &bytesFiller); mlir::ElementsAttr createWithDefaultStrides(mlir::ShapedType type, - BType bufferBType, std::unique_ptr membuf); + BType bufferBType, std::shared_ptr membuf, + uint64_t offset = 0); // Create a DisposableElementsAttr and put it in disposablePool. mlir::ElementsAttr create(mlir::ShapedType type, BType bufferBType, diff --git a/src/Dialect/ONNX/ONNXAttributes.cpp b/src/Dialect/ONNX/ONNXAttributes.cpp index 48fdc4359e..f2940ea6c5 100644 --- a/src/Dialect/ONNX/ONNXAttributes.cpp +++ b/src/Dialect/ONNX/ONNXAttributes.cpp @@ -127,7 +127,8 @@ Attribute ONNXDialect::parseAttribute( auto shapedTy = type.cast(); if (auto membuf = DisposableElementsAttr::parse(parser, shapedTy)) return OnnxElementsAttrBuilder(type.getContext()) - .fromMemoryBuffer(shapedTy, std::move(membuf)); + .fromMemoryBuffer( + shapedTy, std::move(membuf), 0, membuf->getBufferSize()); else return {}; } From 6b9ae99f3a3bea218a8b0d82c8caa04a1e252e3d Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 11:43:04 -0700 Subject: [PATCH 05/12] keep ElementsAttrBuilder::fromMemoryBuffer() without offset, length Signed-off-by: Soren Lassen --- src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 5 +++++ src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp | 3 +++ src/Dialect/ONNX/ONNXAttributes.cpp | 3 +-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 4c8bb53822..17a663de7b 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -65,6 +65,11 @@ struct ElementsAttrBuilder::ElementsProperties { ElementsAttrBuilder::ElementsAttrBuilder(DisposablePool &disposablePool) : disposablePool(disposablePool) {} +ElementsAttr ElementsAttrBuilder::fromMemoryBuffer( + ShapedType type, std::unique_ptr membuf) { + return fromMemoryBuffer(type, std::move(membuf), 0, membuf->getBufferSize()); +} + ElementsAttr ElementsAttrBuilder::fromMemoryBuffer(ShapedType type, std::shared_ptr membuf, uint64_t offset, uint64_t length) { diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index d87e5edd72..2e22a8da14 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -34,6 +34,9 @@ class ElementsAttrBuilder { // The created instance takes ownership of membuf and will release it when the // instance is disposed by garbage collection, unless it has shared membuf // with other DisposableElementsAttr instances that are longer lived. + mlir::ElementsAttr fromMemoryBuffer( + mlir::ShapedType type, std::unique_ptr membuf); + mlir::ElementsAttr fromMemoryBuffer(mlir::ShapedType type, std::shared_ptr membuf, uint64_t offset, uint64_t length); diff --git a/src/Dialect/ONNX/ONNXAttributes.cpp b/src/Dialect/ONNX/ONNXAttributes.cpp index f2940ea6c5..48fdc4359e 100644 --- a/src/Dialect/ONNX/ONNXAttributes.cpp +++ b/src/Dialect/ONNX/ONNXAttributes.cpp @@ -127,8 +127,7 @@ Attribute ONNXDialect::parseAttribute( auto shapedTy = type.cast(); if (auto membuf = DisposableElementsAttr::parse(parser, shapedTy)) return OnnxElementsAttrBuilder(type.getContext()) - .fromMemoryBuffer( - shapedTy, std::move(membuf), 0, membuf->getBufferSize()); + .fromMemoryBuffer(shapedTy, std::move(membuf)); else return {}; } From 785863befd8805c20cc6452914e69d9077ba86a8 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 14:45:45 -0700 Subject: [PATCH 06/12] set attr external_data_dir when any attributes have external data Signed-off-by: Soren Lassen --- src/Builder/FrontendDialectTransformer.cpp | 3 +++ test/mlir/onnx/parse/external_data.json | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 11ae8a60a2..0e2c4f6ba3 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -1424,6 +1424,9 @@ class FrontendGenImpl { auto funcType = importGraph(graph, /*region=*/mainFunc.getBody(), /*op=*/mainFunc.getOperation(), /*useReturn=*/true); mainFunc.setType(funcType); + if (!externalDataFiles_.empty()) + mainFunc->setAttr("external_data_dir", + builder_.getStringAttr(options_.externalDataDir)); // Emit entry point op describing inference function signature. auto entryPoint = ONNXEntryPointOp::create(UnknownLoc(), mainFunc); diff --git a/test/mlir/onnx/parse/external_data.json b/test/mlir/onnx/parse/external_data.json index 30f31bb3fa..76c11682d2 100644 --- a/test/mlir/onnx/parse/external_data.json +++ b/test/mlir/onnx/parse/external_data.json @@ -608,7 +608,7 @@ ] } // CHECK-LABEL: func.func @main_graph -// CHECK-SAME: () -> (tensor<3xf32>, tensor<3xui8>, tensor<3xi8>, tensor<3xui16>, tensor<3xi16>, tensor<3xi32>, tensor<3xi64>, tensor<3xi1>, tensor<3xf16>, tensor<3xf64>, tensor<3xui32>, tensor<3xui64>) attributes {input_names = [], output_names = ["output0", "output1", "output2", "output3", "output4", "output5", "output6", "output7", "output8", "output9", "output10", "output11"]} { +// CHECK-SAME: () -> (tensor<3xf32>, tensor<3xui8>, tensor<3xi8>, tensor<3xui16>, tensor<3xi16>, tensor<3xi32>, tensor<3xi64>, tensor<3xi1>, tensor<3xf16>, tensor<3xf64>, tensor<3xui32>, tensor<3xui64>) attributes {external_data_dir = "{{.*}}/test/mlir/onnx/parse", input_names = [], output_names = ["output0", "output1", "output2", "output3", "output4", "output5", "output6", "output7", "output8", "output9", "output10", "output11"]} { // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 0.000000e+00, 1.000000e+00]> : tensor<3xf32> // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 0, 1]> : tensor<3xui8> // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[1, 0, 1]> : tensor<3xi8> From c0134b7097e137ffd55e942f4d362766319fbe03 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 15:54:13 -0700 Subject: [PATCH 07/12] add getStridedSize(), remove areStridesSplat() Signed-off-by: Soren Lassen --- .../ONNX/ElementsAttr/DisposableElementsAttr.cpp | 10 ++-------- src/Dialect/ONNX/ElementsAttr/Strides.cpp | 11 +++++++++++ src/Dialect/ONNX/ElementsAttr/Strides.hpp | 9 +++++---- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp index 34b5de9e81..e97ce0dc6b 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp @@ -79,7 +79,7 @@ void DisposableElementsAttr::dispose() { } bool DisposableElementsAttr::isSplat() const { - return areStridesSplat(getStrides()) && getBuffer()->getBufferSize() != 0; + return getNumBufferElements() == 1; } BType DisposableElementsAttr::getBType() const { return getImpl()->btype; } @@ -127,13 +127,7 @@ unsigned DisposableElementsAttr::getBufferElementBytewidth() const { } int64_t DisposableElementsAttr::getNumBufferElements() const { - int64_t lastPos = 0; - for (auto [dimSize, stride] : llvm::zip(getShape(), getStrides())) { - if (dimSize == 0) - return 0; - lastPos += (dimSize - 1) * stride; - } - return lastPos + 1; + return getStridedSize(getShape(), getStrides()); } ArrayBuffer DisposableElementsAttr::getWideNums() const { diff --git a/src/Dialect/ONNX/ElementsAttr/Strides.cpp b/src/Dialect/ONNX/ElementsAttr/Strides.cpp index 762da77ab2..d525ee9e65 100644 --- a/src/Dialect/ONNX/ElementsAttr/Strides.cpp +++ b/src/Dialect/ONNX/ElementsAttr/Strides.cpp @@ -28,6 +28,17 @@ uint64_t getStridesPosition( return pos; } +int64_t getStridedSize(ArrayRef shape, ArrayRef strides) { + assert(shape.size() == strides.size()); + int64_t lastPos = 0; + for (auto [dimSize, stride] : llvm::zip(shape, strides)) { + if (dimSize == 0) + return 0; + lastPos += (dimSize - 1) * stride; + } + return lastPos + 1; +} + bool areStridesContiguous(ArrayRef shape, ArrayRef strides) { unsigned rank = shape.size(); assert(rank == strides.size()); diff --git a/src/Dialect/ONNX/ElementsAttr/Strides.hpp b/src/Dialect/ONNX/ElementsAttr/Strides.hpp index 9ee4df7bd5..0e988b0eca 100644 --- a/src/Dialect/ONNX/ElementsAttr/Strides.hpp +++ b/src/Dialect/ONNX/ElementsAttr/Strides.hpp @@ -54,10 +54,11 @@ namespace onnx_mlir { uint64_t getStridesPosition( llvm::ArrayRef index, llvm::ArrayRef strides); -// The data is splat (singleton) if strides are all zero. -inline bool areStridesSplat(llvm::ArrayRef strides) { - return llvm::all_of(strides, [](int64_t s) { return s == 0; }); -} +// Size of the underlying linear array that represents shape with the given +// strides. Returns 0 if shape is empty. Returns 1 if strides are splat (all +// zeros) and shape is non-empty. +int64_t getStridedSize( + llvm::ArrayRef shape, llvm::ArrayRef strides); // Returns strides == getDefaultStrides(shape, strides). bool areStridesContiguous( From 1fd36c60108af638dda5527fb367014f92c05ff8 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 16:50:56 -0700 Subject: [PATCH 08/12] pass length to DisposableElementsAttr::create() and perform checks Signed-off-by: Soren Lassen --- .../ElementsAttr/DisposableElementsAttr.cpp | 10 ++++- .../ElementsAttr/DisposableElementsAttr.hpp | 2 +- .../ONNX/ElementsAttr/DisposablePool.cpp | 6 +-- .../ONNX/ElementsAttr/DisposablePool.hpp | 2 +- .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 45 ++++++++++++------- .../ONNX/ElementsAttr/ElementsAttrBuilder.hpp | 4 +- 6 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp index e97ce0dc6b..8ceba60add 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp @@ -58,12 +58,20 @@ void widenArray( /*static*/ DisposableElementsAttr DisposableElementsAttr::create(ShapedType type, size_t id, BType bufferBType, ArrayRef strides, - const Buffer &buffer, uint64_t offset, Transformer transformer) { + const Buffer &buffer, uint64_t offset, uint64_t length, + Transformer transformer) { + assert(offset <= buffer->getBufferSize() && "offset out of range"); + assert(length <= buffer->getBufferSize() - offset && "length out of range"); BType btype = btypeOfMlirType(type.getElementType()); assert((transformer != nullptr || wideBTypeOfBType(bufferBType) == wideBTypeOfBType(btype)) && "buffer wide type mismatch requires transformer"); bool isContiguous = areStridesContiguous(type.getShape(), strides); + int64_t numBufferElements = isContiguous + ? type.getNumElements() + : getStridedSize(type.getShape(), strides); + uint64_t numBytes = numBufferElements * bytewidthOfBType(bufferBType); + assert(numBytes == length && "length mismatch"); DisposableElementsAttr a = Base::get( type.getContext(), type, strides, bufferBType, btype, isContiguous, id); DisposableElementsAttributeStorage &s = *a.getImpl(); diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index 8308e9ea6a..d16a5a7f4e 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -118,7 +118,7 @@ class DisposableElementsAttr // created instance. static DisposableElementsAttr create(ShapedType type, size_t id, BType bufferBType, ArrayRef strides, const Buffer &buffer, - uint64_t offset, Transformer transformer); + uint64_t offset, uint64_t length, Transformer transformer); // Clears the buffer payload shared_ptr which decreases the reference count // and, if it reaches zero, frees or closes the underlying MemoryBuffer's diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp index 60b87c0b1b..392bc1cced 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp @@ -25,11 +25,11 @@ DisposablePool::~DisposablePool() {} ElementsAttr DisposablePool::createElementsAttr(ShapedType type, BType bufferBType, ArrayRef strides, const mlir::DisposableElementsAttr::Buffer &buffer, uint64_t offset, - DisposableElementsAttr::Transformer transformer) { + uint64_t length, DisposableElementsAttr::Transformer transformer) { static std::atomic counter{0}; size_t id = ++counter; - auto disposable = DisposableElementsAttr::create( - type, id, bufferBType, strides, buffer, offset, std::move(transformer)); + auto disposable = DisposableElementsAttr::create(type, id, bufferBType, + strides, buffer, offset, length, std::move(transformer)); if (insert(disposable)) { return disposable; } else { diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp index 08fdff8af8..95533df639 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp @@ -58,7 +58,7 @@ class DisposablePool : public mlir::DialectInterface::Base { mlir::ElementsAttr createElementsAttr(mlir::ShapedType type, BType bufferBType, llvm::ArrayRef strides, const mlir::DisposableElementsAttr::Buffer &buffer, uint64_t offset, - mlir::DisposableElementsAttr::Transformer transformer); + uint64_t length, mlir::DisposableElementsAttr::Transformer transformer); // Disposes every DisposableElementsAttr in the pool which is unreachable // (doesn't appear in moduleOp). diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 17a663de7b..0d38694914 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -59,6 +59,7 @@ struct ElementsAttrBuilder::ElementsProperties { SmallVector strides; std::shared_ptr buffer; uint64_t offset; + uint64_t length; const Transformer &transformer; }; @@ -76,7 +77,8 @@ ElementsAttr ElementsAttrBuilder::fromMemoryBuffer(ShapedType type, BType btype = btypeOfMlirType(type.getElementType()); uint64_t numBytes = type.getNumElements() * bytewidthOfBType(btype); assert(numBytes == length && "length mismatch"); - return createWithDefaultStrides(type, btype, std::move(membuf), offset); + return createWithDefaultStrides( + type, btype, std::move(membuf), offset, length); } DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( @@ -87,8 +89,9 @@ DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( if (!disposablePool.isActive()) return nullptr; ElementsProperties props = getElementsProperties(dense); - ElementsAttr created = create(dense.getType(), props.bufferBType, - props.strides, props.buffer, props.offset, props.transformer); + ElementsAttr created = + create(dense.getType(), props.bufferBType, props.strides, props.buffer, + props.offset, props.length, props.transformer); // Check for race condition where disposablePool became inactive since we // checked, in which case it returns a DenseElementsAttr which we don't // want. @@ -337,7 +340,7 @@ ElementsAttr ElementsAttrBuilder::castElementType( : composeTransforms(props.transformer, functionTransformer(wideCaster(oldWideType, newWideType))); return create(newType, props.bufferBType, props.strides, props.buffer, - props.offset, std::move(transformer)); + props.offset, props.length, std::move(transformer)); } namespace { @@ -362,7 +365,7 @@ ElementsAttr ElementsAttrBuilder::transpose( ShapedType transposedType = type.clone(transposedShape); auto transposedStrides = transposeDims(props.strides, perm); return create(transposedType, props.bufferBType, transposedStrides, - props.buffer, props.offset, props.transformer); + props.buffer, props.offset, props.length, props.transformer); } ElementsAttr ElementsAttrBuilder::reshape( @@ -378,7 +381,7 @@ ElementsAttr ElementsAttrBuilder::reshape( if (auto reshapedStrides = reshapeStrides(shape, props.strides, reshapedShape)) return create(reshapedType, props.bufferBType, *reshapedStrides, - props.buffer, props.offset, props.transformer); + props.buffer, props.offset, props.length, props.transformer); auto disp = elms.dyn_cast(); assert(disp && "reshapeStrides() always succeeds for non-Disposable " @@ -408,7 +411,7 @@ ElementsAttr ElementsAttrBuilder::expand( ShapedType expandedType = type.clone(expandedShape); auto expandedStrides = expandStrides(props.strides, expandedShape); return create(expandedType, props.bufferBType, expandedStrides, props.buffer, - props.offset, props.transformer); + props.offset, props.length, props.transformer); } namespace { @@ -842,10 +845,14 @@ auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) static Transformer nullTransformer = nullptr; if (auto disposable = elements.dyn_cast()) { ArrayRef strides = disposable.getStrides(); - return {/*.bufferBType=*/disposable.getBufferBType(), + BType bufferBType = disposable.getBufferBType(); + uint64_t length = + disposable.getNumBufferElements() * bytewidthOfBType(bufferBType); + return {/*.bufferBType=*/bufferBType, /*.strides=*/{strides.begin(), strides.end()}, /*.buffer=*/disposable.getBuffer(), /*.offset=*/disposable.getOffset(), + /*.length=*/length, /*.transformer=*/disposable.getTransformer()}; } else if (auto dense = elements.dyn_cast()) { ShapedType type = dense.getType(); @@ -855,10 +862,13 @@ auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) } else { strides = getDefaultStrides(type.getShape()); } + std::unique_ptr buffer = getMemoryBuffer(dense); + uint64_t length = buffer->getBufferSize(); return {/*.bufferBType=*/btypeOfMlirType(type.getElementType()), /*.strides=*/{strides.begin(), strides.end()}, - /*.buffer=*/getMemoryBuffer(dense), + /*.buffer=*/std::move(buffer), /*.offset=*/0, + /*.length=*/length, /*.transformer=*/nullTransformer}; } // TODO: consider supporting more ElementsAttr types @@ -890,7 +900,7 @@ ElementsAttr ElementsAttrBuilder::doTransform( ElementsProperties props = getElementsProperties(elms); return create(transformedType, props.bufferBType, props.strides, props.buffer, - props.offset, + props.offset, props.length, composeTransforms(props.transformer, std::move(transformer))); } @@ -902,7 +912,7 @@ ElementsAttr ElementsAttrBuilder::expandAndTransform(ElementsAttr elms, expandStrides(props.strides, expandedTransformedType.getShape()); return create(expandedTransformedType, props.bufferBType, expandedStrides, - props.buffer, props.offset, + props.buffer, props.offset, props.length, composeTransforms(props.transformer, std::move(transformer))); } @@ -912,22 +922,23 @@ ElementsAttr ElementsAttrBuilder::fromRawBytes( std::unique_ptr writeBuffer = llvm::WritableMemoryBuffer::getNewUninitMemBuffer(size); bytesFiller(writeBuffer->getBuffer()); - return createWithDefaultStrides(type, bufferBType, std::move(writeBuffer)); + return createWithDefaultStrides( + type, bufferBType, std::move(writeBuffer), 0, size); } ElementsAttr ElementsAttrBuilder::createWithDefaultStrides(ShapedType type, BType bufferBType, std::shared_ptr membuf, - uint64_t offset) { + uint64_t offset, uint64_t length) { auto strides = getDefaultStrides(type.getShape()); - return create(type, bufferBType, strides, std::move(membuf), offset); + return create(type, bufferBType, strides, std::move(membuf), offset, length); } ElementsAttr ElementsAttrBuilder::create(ShapedType type, BType bufferBType, ArrayRef strides, const std::shared_ptr &buffer, uint64_t offset, - Transformer transformer) { - return disposablePool.createElementsAttr( - type, bufferBType, strides, buffer, offset, std::move(transformer)); + uint64_t length, Transformer transformer) { + return disposablePool.createElementsAttr(type, bufferBType, strides, buffer, + offset, length, std::move(transformer)); } } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 2e22a8da14..38666f9e1f 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -233,13 +233,13 @@ class ElementsAttrBuilder { mlir::ElementsAttr createWithDefaultStrides(mlir::ShapedType type, BType bufferBType, std::shared_ptr membuf, - uint64_t offset = 0); + uint64_t offset, uint64_t length); // Create a DisposableElementsAttr and put it in disposablePool. mlir::ElementsAttr create(mlir::ShapedType type, BType bufferBType, llvm::ArrayRef strides, const std::shared_ptr &buffer, uint64_t offset, - Transformer transformer = nullptr); + uint64_t length, Transformer transformer = nullptr); DisposablePool &disposablePool; }; From 98f02acbb9486895410585fbcf79ac787a346c98 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 18:35:12 -0700 Subject: [PATCH 09/12] update external_data parse lit test to work for Windows paths Signed-off-by: Soren Lassen --- test/mlir/onnx/parse/external_data.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mlir/onnx/parse/external_data.json b/test/mlir/onnx/parse/external_data.json index 76c11682d2..5208091607 100644 --- a/test/mlir/onnx/parse/external_data.json +++ b/test/mlir/onnx/parse/external_data.json @@ -608,7 +608,7 @@ ] } // CHECK-LABEL: func.func @main_graph -// CHECK-SAME: () -> (tensor<3xf32>, tensor<3xui8>, tensor<3xi8>, tensor<3xui16>, tensor<3xi16>, tensor<3xi32>, tensor<3xi64>, tensor<3xi1>, tensor<3xf16>, tensor<3xf64>, tensor<3xui32>, tensor<3xui64>) attributes {external_data_dir = "{{.*}}/test/mlir/onnx/parse", input_names = [], output_names = ["output0", "output1", "output2", "output3", "output4", "output5", "output6", "output7", "output8", "output9", "output10", "output11"]} { +// CHECK-SAME: () -> (tensor<3xf32>, tensor<3xui8>, tensor<3xi8>, tensor<3xui16>, tensor<3xi16>, tensor<3xi32>, tensor<3xi64>, tensor<3xi1>, tensor<3xf16>, tensor<3xf64>, tensor<3xui32>, tensor<3xui64>) attributes {external_data_dir = "{{.*}}parse", input_names = [], output_names = ["output0", "output1", "output2", "output3", "output4", "output5", "output6", "output7", "output8", "output9", "output10", "output11"]} { // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 0.000000e+00, 1.000000e+00]> : tensor<3xf32> // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 0, 1]> : tensor<3xui8> // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[1, 0, 1]> : tensor<3xi8> From f2e0bdf833ea7c14716db464977a170a5109f3fe Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Wed, 21 Jun 2023 23:23:50 -0700 Subject: [PATCH 10/12] parse dense_disposable Signed-off-by: Soren Lassen --- .../ElementsAttr/DisposableElementsAttr.cpp | 34 +++----------- .../ElementsAttr/DisposableElementsAttr.hpp | 9 ++-- .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 47 +++++++++++++++++++ .../ONNX/ElementsAttr/ElementsAttrBuilder.hpp | 5 ++ src/Dialect/ONNX/ONNXAttributes.cpp | 15 +++--- 5 files changed, 69 insertions(+), 41 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp index 8ceba60add..03f9ab1c45 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp @@ -176,35 +176,15 @@ bool shouldSwapLEBytes(unsigned elementByteWidth) { } // namespace /*static*/ -std::unique_ptr DisposableElementsAttr::parse( - AsmParser &parser, ShapedType type) { - size_t id = 0; // The parsed id is ignored. - std::string str; +Attribute DisposableElementsAttr::parse(AsmParser &parser, Type type, + function_ref parseElements) { + size_t id = 0; // The parsed id. + ElementsAttr elms; if (parser.parseLess() || parser.parseInteger(id) || parser.parseColon() || - parser.parseString(&str)) + parseElements(id, elms) || parser.parseGreater()) return nullptr; - StringRef hex = str; - std::string bytes; - if (!hex.consume_front("0x") || (hex.size() & 1) || - !llvm::tryGetFromHex(hex, bytes)) { - parser.emitError(parser.getCurrentLocation(), "ill-formed hex string"); - return nullptr; - } - if (bytes.size() != static_cast(getSizeInBytes(type))) { - parser.emitError( - parser.getCurrentLocation(), "data size doesn't match type size"); - return nullptr; - } - if (!shouldSwapLEBytes(getIntOrFloatByteWidth(type.getElementType()))) { - return llvm::MemoryBuffer::getMemBufferCopy(bytes); - } else { - // Reorder bytes from little-endian on big-endian platforms: - std::unique_ptr writeBuffer = - llvm::WritableMemoryBuffer::getNewUninitMemBuffer(bytes.size()); - DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( - {bytes.data(), bytes.size()}, writeBuffer->getBuffer(), type); - return writeBuffer; - } + + return elms; } void DisposableElementsAttr::printWithoutType(AsmPrinter &printer) const { diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index d16a5a7f4e..b41a2ac097 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -21,8 +21,8 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/Sequence.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/MemoryBuffer.h" #include @@ -270,11 +270,8 @@ class DisposableElementsAttr static constexpr StringLiteral getMnemonic() { return {"dense_disposable"}; } - // Returns the underlying data as a flat byte array in row-major order. - // If the element type is bool the data holds one byte (with value 0 or 1) per - // bool (contrary to how DenseElementsAttr::getRawData() bit packs bools). - static std::unique_ptr parse( - AsmParser &parser, ShapedType type); + static Attribute parse(AsmParser &parser, Type type, + function_ref parseElements); void printWithoutType(AsmPrinter &printer) const; diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 0d38694914..6eb62da4c2 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -11,6 +11,8 @@ #include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp" #include "mlir/Dialect/Traits.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Endian.h" #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" #include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp" @@ -66,6 +68,51 @@ struct ElementsAttrBuilder::ElementsProperties { ElementsAttrBuilder::ElementsAttrBuilder(DisposablePool &disposablePool) : disposablePool(disposablePool) {} +namespace { +// Perform byte swap if system endianness is BE and elements are multi-byte. +bool shouldSwapLEBytes(unsigned elementByteWidth) { + return elementByteWidth > 1 && llvm::support::endian::system_endianness() != + llvm::support::endianness::little; +} +} // namespace + +ParseResult ElementsAttrBuilder::parseElements( + AsmParser &parser, ShapedType type, size_t id, ElementsAttr &elms) { + std::string str; + if (parser.parseString(&str)) + return failure(); + if (!parser.parseOptionalColon()) { + uint64_t offset = 0; + uint64_t length = 0; + if (parser.parseInteger(offset) || parser.parseColon() || + parser.parseInteger(length)) + return failure(); + return parser.emitError(parser.getCurrentLocation(), "TODO: implement"); + } else { + StringRef hex = str; + std::string bytes; + if (!hex.consume_front("0x") || (hex.size() & 1) || + !llvm::tryGetFromHex(hex, bytes)) + return parser.emitError( + parser.getCurrentLocation(), "ill-formed hex string"); + if (bytes.size() != static_cast(getSizeInBytes(type))) + return parser.emitError( + parser.getCurrentLocation(), "data size doesn't match type size"); + if (!shouldSwapLEBytes(getIntOrFloatByteWidth(type.getElementType()))) { + elms = + fromMemoryBuffer(type, llvm::MemoryBuffer::getMemBufferCopy(bytes)); + } else { + // Reorder bytes from little-endian on big-endian platforms: + std::unique_ptr writeBuffer = + llvm::WritableMemoryBuffer::getNewUninitMemBuffer(bytes.size()); + DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( + {bytes.data(), bytes.size()}, writeBuffer->getBuffer(), type); + elms = fromMemoryBuffer(type, std::move(writeBuffer)); + } + return success(); + } +} + ElementsAttr ElementsAttrBuilder::fromMemoryBuffer( ShapedType type, std::unique_ptr membuf) { return fromMemoryBuffer(type, std::move(membuf), 0, membuf->getBufferSize()); diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 38666f9e1f..3dd3ee57ed 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -16,6 +16,8 @@ #include "src/Dialect/ONNX/ElementsAttr/Strides.hpp" #include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp" +#include "mlir/IR/OpImplementation.h" + #include #include @@ -30,6 +32,9 @@ class ElementsAttrBuilder { // in the builder methods. ElementsAttrBuilder(DisposablePool &disposablePool); + mlir::ParseResult parseElements(mlir::AsmParser &parser, + mlir::ShapedType type, size_t id, mlir::ElementsAttr &elms); + // Creates a DisposableElementsAttr instance backed by the data in membuf. // The created instance takes ownership of membuf and will release it when the // instance is disposed by garbage collection, unless it has shared membuf diff --git a/src/Dialect/ONNX/ONNXAttributes.cpp b/src/Dialect/ONNX/ONNXAttributes.cpp index 48fdc4359e..09f306c18f 100644 --- a/src/Dialect/ONNX/ONNXAttributes.cpp +++ b/src/Dialect/ONNX/ONNXAttributes.cpp @@ -119,17 +119,16 @@ void ONNXDialect::registerAttributes() { Attribute ONNXDialect::parseAttribute( DialectAsmParser &parser, Type type) const { // generatedAttributeParser is generated in ONNXAttributes.cpp.inc + Attribute attr; StringRef attrTag; - if (Attribute attr; - generatedAttributeParser(parser, &attrTag, type, attr).has_value()) + if (generatedAttributeParser(parser, &attrTag, type, attr).has_value()) return attr; if (attrTag == DisposableElementsAttr::getMnemonic()) { - auto shapedTy = type.cast(); - if (auto membuf = DisposableElementsAttr::parse(parser, shapedTy)) - return OnnxElementsAttrBuilder(type.getContext()) - .fromMemoryBuffer(shapedTy, std::move(membuf)); - else - return {}; + return DisposableElementsAttr::parse( + parser, type, [&](size_t id, ElementsAttr &elms) -> ParseResult { + return OnnxElementsAttrBuilder(type.getContext()) + .parseElements(parser, cast(type), id, elms); + }); } parser.emitError(parser.getCurrentLocation()) << "unknown attribute `" << attrTag << "` in dialect `ONNX`"; From 3805e19358c1eabc9b97303f1676b060be3d62c0 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Thu, 22 Jun 2023 06:47:24 -0700 Subject: [PATCH 11/12] add --make_raw flag to onnxExternalizeData.py Signed-off-by: Soren Lassen --- utils/onnxExternalizeData.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/utils/onnxExternalizeData.py b/utils/onnxExternalizeData.py index f801537044..398b7d4b2b 100755 --- a/utils/onnxExternalizeData.py +++ b/utils/onnxExternalizeData.py @@ -6,27 +6,54 @@ # Converts the data in tensors in an onnx model to external data. # Useful tool for constructing external data examples for testing. # +# Call with --make_raw to convert non-raw tensors to raw_data to make them +# eligible to become external data, otherwise +# onnx.save_model(model, path, save_as_external_data=True) +# doesn't convert them to external data. +# ################################################################################ import argparse -import os import onnx +import os +import sys parser = argparse.ArgumentParser() parser.add_argument('model_path', type=str, help="Path to the ONNX model") +parser.add_argument('--size_threshold', type=int, default=0, help="Only convert tensors with byte size no smaller than this") parser.add_argument('--no_all_tensors_to_one_file', action='store_true', help="Save tensors to multiple files") parser.add_argument('--no_convert_attribute', action='store_true', help="Only convert initializer tensors to external data") +parser.add_argument('--make_raw', action='store_true', help="Convert non-raw tensors to raw_data") args = parser.parse_args() +def get_all_tensors(onnx_model_proto): + return onnx.external_data_helper._get_all_tensors(onnx_model_proto) + +def get_initializer_tensors(onnx_model_proto): + return onnx.external_data_helper._get_initializer_tensors(onnx_model_proto) + def main(): filepath = args.model_path basename = os.path.basename(filepath) model = onnx.load_model(filepath) + if args.make_raw: + tensors = get_initializer_tensors(model) if args.no_convert_attribute else get_all_tensors(model) + for tensor in tensors: + if not tensor.HasField("raw_data") and tensor.data_type != onnx.TensorProto.STRING: + arr = onnx.numpy_helper.to_array(tensor) + bytes = arr.tobytes() + if sys.getsizeof(bytes) >= args.size_threshold: + storage_field = onnx.helper.tensor_dtype_to_field(tensor.data_type) + tensor.ClearField(storage_field) + tensor.raw_data = bytes + if sys.byteorder == "big": + # Convert endian from big to little + onnx.numpy_helper.convert_endian(tensor) onnx.save_model(model, args.model_path, save_as_external_data=True, all_tensors_to_one_file=not args.no_all_tensors_to_one_file, location=f"{basename}.ext", - size_threshold=0, + size_threshold=args.size_threshold, convert_attribute=not args.no_convert_attribute ) From 8bcb07e5e3fe7d821c456f5291a1364beb9a40c0 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Thu, 22 Jun 2023 07:59:11 -0700 Subject: [PATCH 12/12] 1st step to move DisposableElementsAttr parsing/printing to ONNXAttributes.cpp Signed-off-by: Soren Lassen --- .../ElementsAttr/DisposableElementsAttr.cpp | 69 --------- .../ElementsAttr/DisposableElementsAttr.hpp | 11 -- .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 47 ------ .../ONNX/ElementsAttr/ElementsAttrBuilder.hpp | 5 - src/Dialect/ONNX/ONNXAttributes.cpp | 144 +++++++++++++++++- src/Dialect/ONNX/ONNXOps.cpp | 24 ++- 6 files changed, 159 insertions(+), 141 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp index 03f9ab1c45..3c2f0c43d0 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp @@ -14,9 +14,6 @@ #include "src/Dialect/ONNX/ElementsAttr/Strides.hpp" #include "src/Support/TypeUtilities.hpp" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/Endian.h" - #include #include @@ -167,72 +164,6 @@ DenseElementsAttr DisposableElementsAttr::toDenseElementsAttr() const { return DenseElementsAttr::getFromRawBuffer(getType(), bytes.get()); } -namespace { -// Perform byte swap if system endianness is BE and elements are multi-byte. -bool shouldSwapLEBytes(unsigned elementByteWidth) { - return elementByteWidth > 1 && llvm::support::endian::system_endianness() != - llvm::support::endianness::little; -} -} // namespace - -/*static*/ -Attribute DisposableElementsAttr::parse(AsmParser &parser, Type type, - function_ref parseElements) { - size_t id = 0; // The parsed id. - ElementsAttr elms; - if (parser.parseLess() || parser.parseInteger(id) || parser.parseColon() || - parseElements(id, elms) || parser.parseGreater()) - return nullptr; - - return elms; -} - -void DisposableElementsAttr::printWithoutType(AsmPrinter &printer) const { - // It would be ideal if we could read the printer flags from printer instead - // of constructing them here, because printer may have been constructed with - // an override of elideLargeElementsAttrs which we cannot see here. - // Oh well, at least OpPrintingFlags().shouldElideElementsAttr(ElementsAttr) - // lets us respect the --mlir-elide-elementsattrs-if-larger command line flag. - static OpPrintingFlags printerFlags{}; - printer << getMnemonic() << "<" << getImpl()->id << ":"; - if (!printerFlags.shouldElideElementsAttr(*this)) { - auto rawBytes = getRawBytes(); - SmallVector buffer; - ArrayRef bytes; - if (!shouldSwapLEBytes(getIntOrFloatByteWidth(getElementType()))) { - bytes = rawBytes.get(); - } else { - // Reorder raw bytes to little-endian on big-endian platforms: - buffer.resize_for_overwrite(rawBytes.get().size()); - DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( - rawBytes.get(), buffer, getType()); - ArrayRef bufferRef(buffer); - bytes = bufferRef; - } - printer << "\"0x" << llvm::toHex(castArrayRef(bytes)) << "\""; - } else { - printer << "__elided__"; - } - printer << ">"; -} - -void DisposableElementsAttr::printAsDenseElementsAttr( - AsmPrinter &printer) const { - static OpPrintingFlags printerFlags{}; - if (isSplat() || !printerFlags.shouldElideElementsAttr(*this)) { - // Take shortcut by first converting to DenseElementsAttr. - // NOTE: This creates a copy which is never garbage collected. This is not - // only slow but also defeats the garbage collection benefits of - // DisposableElementsAttr, depending on when the printing - // takes place (the print at the end of onnx-mlir-opt in lit tests is ok). - printer.printAttribute(toDenseElementsAttr()); - // TODO: Do the work to print without constructing DenseElementsAttr. - } else { - // In this special case it's easy to avoid conversion to DenseElementsAttr. - printer << "dense<__elided__> : " << getType(); - } -} - void DisposableElementsAttr::readBytesAsWideNums( ArrayRef srcBytes, llvm::MutableArrayRef dst) const { widenArray(getBufferBType(), srcBytes, dst); diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index b41a2ac097..ed3da86290 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -19,9 +19,7 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/MemoryBuffer.h" @@ -268,15 +266,6 @@ class DisposableElementsAttr // Makes deep copy. DenseElementsAttr toDenseElementsAttr() const; - static constexpr StringLiteral getMnemonic() { return {"dense_disposable"}; } - - static Attribute parse(AsmParser &parser, Type type, - function_ref parseElements); - - void printWithoutType(AsmPrinter &printer) const; - - void printAsDenseElementsAttr(AsmPrinter &printer) const; - private: // Widens and transforms bytes into WideNums in accordance with // bufferDType and transformer. diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 6eb62da4c2..0d38694914 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -11,8 +11,6 @@ #include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp" #include "mlir/Dialect/Traits.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/Endian.h" #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" #include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp" @@ -68,51 +66,6 @@ struct ElementsAttrBuilder::ElementsProperties { ElementsAttrBuilder::ElementsAttrBuilder(DisposablePool &disposablePool) : disposablePool(disposablePool) {} -namespace { -// Perform byte swap if system endianness is BE and elements are multi-byte. -bool shouldSwapLEBytes(unsigned elementByteWidth) { - return elementByteWidth > 1 && llvm::support::endian::system_endianness() != - llvm::support::endianness::little; -} -} // namespace - -ParseResult ElementsAttrBuilder::parseElements( - AsmParser &parser, ShapedType type, size_t id, ElementsAttr &elms) { - std::string str; - if (parser.parseString(&str)) - return failure(); - if (!parser.parseOptionalColon()) { - uint64_t offset = 0; - uint64_t length = 0; - if (parser.parseInteger(offset) || parser.parseColon() || - parser.parseInteger(length)) - return failure(); - return parser.emitError(parser.getCurrentLocation(), "TODO: implement"); - } else { - StringRef hex = str; - std::string bytes; - if (!hex.consume_front("0x") || (hex.size() & 1) || - !llvm::tryGetFromHex(hex, bytes)) - return parser.emitError( - parser.getCurrentLocation(), "ill-formed hex string"); - if (bytes.size() != static_cast(getSizeInBytes(type))) - return parser.emitError( - parser.getCurrentLocation(), "data size doesn't match type size"); - if (!shouldSwapLEBytes(getIntOrFloatByteWidth(type.getElementType()))) { - elms = - fromMemoryBuffer(type, llvm::MemoryBuffer::getMemBufferCopy(bytes)); - } else { - // Reorder bytes from little-endian on big-endian platforms: - std::unique_ptr writeBuffer = - llvm::WritableMemoryBuffer::getNewUninitMemBuffer(bytes.size()); - DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( - {bytes.data(), bytes.size()}, writeBuffer->getBuffer(), type); - elms = fromMemoryBuffer(type, std::move(writeBuffer)); - } - return success(); - } -} - ElementsAttr ElementsAttrBuilder::fromMemoryBuffer( ShapedType type, std::unique_ptr membuf) { return fromMemoryBuffer(type, std::move(membuf), 0, membuf->getBufferSize()); diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 3dd3ee57ed..38666f9e1f 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -16,8 +16,6 @@ #include "src/Dialect/ONNX/ElementsAttr/Strides.hpp" #include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp" -#include "mlir/IR/OpImplementation.h" - #include #include @@ -32,9 +30,6 @@ class ElementsAttrBuilder { // in the builder methods. ElementsAttrBuilder(DisposablePool &disposablePool); - mlir::ParseResult parseElements(mlir::AsmParser &parser, - mlir::ShapedType type, size_t id, mlir::ElementsAttr &elms); - // Creates a DisposableElementsAttr instance backed by the data in membuf. // The created instance takes ownership of membuf and will release it when the // instance is disposed by garbage collection, unless it has shared membuf diff --git a/src/Dialect/ONNX/ONNXAttributes.cpp b/src/Dialect/ONNX/ONNXAttributes.cpp index 09f306c18f..55d0704196 100644 --- a/src/Dialect/ONNX/ONNXAttributes.cpp +++ b/src/Dialect/ONNX/ONNXAttributes.cpp @@ -22,7 +22,9 @@ #include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Endian.h" using namespace mlir; using namespace onnx_mlir; @@ -98,6 +100,136 @@ void ONNXTensorEncodingAttr::print(AsmPrinter &printer) const { printer << "}>"; } +//===----------------------------------------------------------------------===// +// ONNX Attribute: DisposablElementsAttr +//===----------------------------------------------------------------------===// + +namespace { +constexpr StringLiteral getDisposablElementsAttrMnemonic() { + return {"dense_disposable"}; +} + +#if 1 +Attribute parseDisposablElementsAttr(AsmParser &parser, Type type) { + llvm_unreachable("TODO: implement"); +} + +void printDisposablElementsAttr( + AsmPrinter &printer, DisposableElementsAttr disposable) { + llvm_unreachable("TODO: implement"); +} +#else +Attribute parseDisposablElementsAttr(AsmParser &parser, Type type) { + return DisposableElementsAttr::parse( + parser, type, [&](size_t id, ElementsAttr &elms) -> ParseResult { + return OnnxElementsAttrBuilder(type.getContext()) + .parseElements(parser, cast(type), id, elms); + }); +} + +void printDisposablElementsAttr( + AsmPrinter &printer, DisposableElementsAttr disposable) { + disposable.printWithoutType(printer); +} + +static Attribute parse(AsmParser &parser, Type type, + function_ref parseElements); + +void printWithoutType(AsmPrinter &printer) const; + +void printAsDenseElementsAttr(AsmPrinter &printer) const; + +namespace { +// Perform byte swap if system endianness is BE and elements are multi-byte. +bool shouldSwapLEBytes(unsigned elementByteWidth) { + return elementByteWidth > 1 && llvm::support::endian::system_endianness() != + llvm::support::endianness::little; +} +} // namespace + +/*static*/ +Attribute DisposableElementsAttr::parse(AsmParser &parser, Type type, + function_ref parseElements) { + size_t id = 0; // The parsed id. + ElementsAttr elms; + if (parser.parseLess() || parser.parseInteger(id) || parser.parseColon() || + parseElements(id, elms) || parser.parseGreater()) + return nullptr; + + return elms; +} + +void DisposableElementsAttr::printWithoutType(AsmPrinter &printer) const { + // It would be ideal if we could read the printer flags from printer instead + // of constructing them here, because printer may have been constructed with + // an override of elideLargeElementsAttrs which we cannot see here. + // Oh well, at least OpPrintingFlags().shouldElideElementsAttr(ElementsAttr) + // lets us respect the --mlir-elide-elementsattrs-if-larger command line flag. + static OpPrintingFlags printerFlags{}; + printer << getMnemonic() << "<" << getImpl()->id << ":"; + if (!printerFlags.shouldElideElementsAttr(*this)) { + auto rawBytes = getRawBytes(); + SmallVector buffer; + ArrayRef bytes; + if (!shouldSwapLEBytes(getIntOrFloatByteWidth(getElementType()))) { + bytes = rawBytes.get(); + } else { + // Reorder raw bytes to little-endian on big-endian platforms: + buffer.resize_for_overwrite(rawBytes.get().size()); + DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( + rawBytes.get(), buffer, getType()); + ArrayRef bufferRef(buffer); + bytes = bufferRef; + } + printer << "\"0x" << llvm::toHex(castArrayRef(bytes)) << "\""; + } else { + printer << "__elided__"; + } + printer << ">"; +} + +mlir::ParseResult parseElements(mlir::AsmParser &parser, mlir::ShapedType type, + size_t id, mlir::ElementsAttr &elms); + +ParseResult ElementsAttrBuilder::parseElements( + AsmParser &parser, ShapedType type, size_t id, ElementsAttr &elms) { + std::string str; + if (parser.parseString(&str)) + return failure(); + if (!parser.parseOptionalColon()) { + uint64_t offset = 0; + uint64_t length = 0; + if (parser.parseInteger(offset) || parser.parseColon() || + parser.parseInteger(length)) + return failure(); + return parser.emitError(parser.getCurrentLocation(), "TODO: implement"); + } else { + StringRef hex = str; + std::string bytes; + if (!hex.consume_front("0x") || (hex.size() & 1) || + !llvm::tryGetFromHex(hex, bytes)) + return parser.emitError( + parser.getCurrentLocation(), "ill-formed hex string"); + if (bytes.size() != static_cast(getSizeInBytes(type))) + return parser.emitError( + parser.getCurrentLocation(), "data size doesn't match type size"); + if (!shouldSwapLEBytes(getIntOrFloatByteWidth(type.getElementType()))) { + elms = + fromMemoryBuffer(type, llvm::MemoryBuffer::getMemBufferCopy(bytes)); + } else { + // Reorder bytes from little-endian on big-endian platforms: + std::unique_ptr writeBuffer = + llvm::WritableMemoryBuffer::getNewUninitMemBuffer(bytes.size()); + DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( + {bytes.data(), bytes.size()}, writeBuffer->getBuffer(), type); + elms = fromMemoryBuffer(type, std::move(writeBuffer)); + } + return success(); + } +} +#endif +} // namespace + //===----------------------------------------------------------------------===// // ONNX Attributes: TableGen generated implementation //===----------------------------------------------------------------------===// @@ -123,12 +255,8 @@ Attribute ONNXDialect::parseAttribute( StringRef attrTag; if (generatedAttributeParser(parser, &attrTag, type, attr).has_value()) return attr; - if (attrTag == DisposableElementsAttr::getMnemonic()) { - return DisposableElementsAttr::parse( - parser, type, [&](size_t id, ElementsAttr &elms) -> ParseResult { - return OnnxElementsAttrBuilder(type.getContext()) - .parseElements(parser, cast(type), id, elms); - }); + if (attrTag == getDisposablElementsAttrMnemonic()) { + return parseDisposablElementsAttr(parser, type); } parser.emitError(parser.getCurrentLocation()) << "unknown attribute `" << attrTag << "` in dialect `ONNX`"; @@ -141,6 +269,6 @@ void ONNXDialect::printAttribute( // generatedAttributePrinter is generated in ONNXAttributes.cpp.inc if (succeeded(generatedAttributePrinter(attr, printer))) return; - if (auto elements = attr.dyn_cast()) - elements.printWithoutType(printer); + if (auto disposable = attr.dyn_cast()) + printDisposablElementsAttr(printer, disposable); } diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 69f0e0076e..a9c9117e42 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -16,6 +16,7 @@ #include "src/Dialect/ONNX/DialectBuilder.hpp" #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" +#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp" #include "mlir/Dialect/Traits.h" #include "llvm/ADT/STLExtras.h" @@ -82,12 +83,33 @@ namespace { // Helpers adapted from corresponding methods in mlir/lib/AsmParser/Parser.cpp //===----------------------------------------------------------------------===// +void printAsDenseElementsAttr(AsmPrinter &printer, ElementsAttr elements) { + // It would be ideal if we could read the printer flags from printer instead + // of constructing them here, because printer may have been constructed with + // an override of elideLargeElementsAttrs which we cannot see here. + // Oh well, at least OpPrintingFlags().shouldElideElementsAttr(ElementsAttr) + // lets us respect the --mlir-elide-elementsattrs-if-larger command line flag. + static OpPrintingFlags printerFlags{}; + if (elements.isSplat() || !printerFlags.shouldElideElementsAttr(elements)) { + // Take shortcut by first converting to DenseElementsAttr. + // NOTE: This creates a copy which is never garbage collected. This is not + // only slow but also defeats the garbage collection benefits of + // DisposableElementsAttr, depending on when the printing + // takes place (the print at the end of onnx-mlir-opt in lit tests is ok). + printer.printAttribute(ElementsAttrBuilder::toDenseElementsAttr(elements)); + // TODO: Do the work to print without constructing DenseElementsAttr. + } else { + // In this special case it's easy to avoid conversion to DenseElementsAttr. + printer << "dense<__elided__> : " << elements.getType(); + } +} + // Print DisposableElementsAttr as a DenseElementsAttr, because // DisposableElementsAttr is an internal representation, so we hide it // in this way. void printAttribute(OpAsmPrinter &printer, Attribute attr) { if (auto disposable = attr.dyn_cast()) - disposable.printAsDenseElementsAttr(printer); + printAsDenseElementsAttr(printer, disposable); else printer.printAttribute(attr); }