diff --git a/src/Builder/FrontendDialectHelper.cpp b/src/Builder/FrontendDialectHelper.cpp index 4f4d56a320..34d6fdc3bc 100644 --- a/src/Builder/FrontendDialectHelper.cpp +++ b/src/Builder/FrontendDialectHelper.cpp @@ -17,8 +17,6 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Endian.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Path.h" #include "llvm/Support/SwapByteOrder.h" #include "src/Dialect/ONNX/ElementsAttr/Arrays.hpp" @@ -33,47 +31,6 @@ namespace onnx_mlir { namespace { -// Parses unsigned number. -size_t parseOffsetOrLength(const std::string &value) { - char *end = nullptr; - size_t offsetOrLength = strtoull(value.c_str(), &end, 0); - assert(end != value.c_str() && "failed to parse offset or length"); - return offsetOrLength; -} - -// Reads external data from file location specified in tensor proto. -// The data is little endian encoded. -// See https://github.com/onnx/onnx/blob/main/docs/ExternalData.md -std::unique_ptr readExternalData_LE( - const std::string &externalDataDir, const onnx::TensorProto &tp) { - std::string location; - uint64_t offset = 0; - uint64_t length = -1; // MemoryBuffer uses -1 to mean infinity - for (const onnx::StringStringEntryProto &entry : tp.external_data()) { - assert(entry.has_key() && "external_data entry must have key"); - assert(entry.has_value() && "external_data entry must have value"); - if (entry.key() == "location") { - location = entry.value(); - } else if (entry.key() == "offset") { - offset = parseOffsetOrLength(entry.value()); - } else if (entry.key() == "length") { - length = parseOffsetOrLength(entry.value()); - } - } - assert(!location.empty() && "missing external data location"); - SmallVector path(externalDataDir.begin(), externalDataDir.end()); - llvm::sys::path::append(path, location); - auto bufferOrError = llvm::MemoryBuffer::getFileSlice( - path, length, offset, /*IsVolatile=*/false); - if (std::error_code ec = bufferOrError.getError()) { - std::string pathStr(path.data(), path.size()); - llvm::errs() << "Error " << ec.message() << " reading from file " << pathStr - << ", offset=" << offset << ", length=" << length << "\n"; - llvm_unreachable("llvm::MemoryBuffer::getFileSlice failed"); - } - return std::move(bufferOrError.get()); -} - template struct TransformValueToONNXData { static const google::protobuf::RepeatedField &data( @@ -156,15 +113,15 @@ T swappedBytes(T x) { template ElementsAttr createElementsAttrFromMemoryBuffer_LE( - RankedTensorType tensorType, std::unique_ptr membuf) { + RankedTensorType tensorType, const ExternalDataFileSlice &fileSlice) { MLIRContext *ctx = tensorType.getContext(); assert(tensorType.getElementType() == toMlirType(ctx)); if constexpr (shouldSwapLEBytes) { - ArrayRef array = asArrayRef(membuf->getBuffer()); + ArrayRef array = asArrayRef(fileSlice.getBufferSlice()); return createElmAttrFromArray(tensorType, array, swappedBytes); } else { return OnnxElementsAttrBuilder(ctx).fromMemoryBuffer( - tensorType, std::move(membuf)); + tensorType, fileSlice.file, fileSlice.offset, fileSlice.length); } } @@ -203,11 +160,12 @@ ElementsAttr createElmAttrFromProtoData(RankedTensorType tensorType, // Returns ElementsAttr with tp's data. template ElementsAttr createElmAttr(RankedTensorType tensorType, - const onnx::TensorProto &tp, const std::string &externalDataDir) { + const onnx::TensorProto &tp, + const ExternalDataFileSlice *externalDataFileSlice) { if (tp.has_data_location() && tp.data_location() == onnx::TensorProto::EXTERNAL) { return createElementsAttrFromMemoryBuffer_LE( - tensorType, readExternalData_LE(externalDataDir, tp)); + tensorType, *externalDataFileSlice); } if (tp.has_raw_data()) { return createElmAttrFromRawBytes_LE( @@ -236,7 +194,8 @@ ElementsAttr createStringElmAttr( } // namespace ElementsAttr onnxTensorProtoToElmAttr(MLIRContext *ctx, - const std::string &externalDataDir, const onnx::TensorProto &tp) { + const onnx::TensorProto &tp, + const ExternalDataFileSlice *externalDataFileSlice) { // Tensor dimensions. ArrayRef tensorDims(tp.dims().data(), tp.dims().size()); if (tp.data_type() == onnx::TensorProto::STRING) { @@ -249,7 +208,7 @@ ElementsAttr onnxTensorProtoToElmAttr(MLIRContext *ctx, auto tensorType = RankedTensorType::get(tensorDims, elmType); return dispatchByBType(btype, [&](auto btype) { using cpptype = CppType; - return createElmAttr(tensorType, tp, externalDataDir); + return createElmAttr(tensorType, tp, externalDataFileSlice); }); } diff --git a/src/Builder/FrontendDialectHelper.hpp b/src/Builder/FrontendDialectHelper.hpp index eb3de87a79..c2ce44483d 100644 --- a/src/Builder/FrontendDialectHelper.hpp +++ b/src/Builder/FrontendDialectHelper.hpp @@ -15,14 +15,26 @@ #pragma once #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/MemoryBuffer.h" #include "onnx/onnx_pb.h" -#include +#include namespace onnx_mlir { +struct ExternalDataFileSlice { + std::shared_ptr file; + uint64_t offset; + uint64_t length; + llvm::StringRef getBufferSlice() const { + return file->getBuffer().substr(offset, length); + } +}; + mlir::ElementsAttr onnxTensorProtoToElmAttr(mlir::MLIRContext *ctx, - const std::string &externalDataDir, const onnx::TensorProto &initializer); + const onnx::TensorProto &initializer, + const ExternalDataFileSlice *externalDataFileSlice = nullptr); } // namespace onnx_mlir diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 0720e7c98c..0e2c4f6ba3 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/LineIterator.h" #include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" #include "include/onnx-mlir/Compiler/OMCompilerTypes.h" #include "src/Builder/FrontendDialectTransformer.hpp" @@ -142,6 +143,14 @@ void replaceAttrRefs(onnx::GraphProto &graph, const AttrMap &attr_map) { // End of copied code from third_party/onnx. // -------------------------------------------------------------------------- // +// Parses unsigned number. +size_t parseOffsetOrLength(const std::string &value) { + char *end = nullptr; + size_t offsetOrLength = strtoull(value.c_str(), &end, 0); + assert(end != value.c_str() && "failed to parse offset or length"); + return offsetOrLength; +} + } // namespace namespace detail { @@ -211,6 +220,66 @@ class FrontendGenImpl { ModelLocalFunctionsMap in_model_functions_; + using ExternalDataFiles = + std::unordered_map>; + + ExternalDataFiles externalDataFiles_; + + const std::shared_ptr &mapExternalDataFile( + const std::string &location) { + auto [iter, inserted] = externalDataFiles_.try_emplace(location, nullptr); + if (inserted) { + StringRef dir = options_.externalDataDir; + SmallVector pathVector(dir.begin(), dir.end()); + llvm::sys::path::append(pathVector, location); + StringRef path(pathVector.data(), pathVector.size()); + // Memory maps file (in most cases) or reads it into memory. + auto bufferOrError = llvm::MemoryBuffer::getFile( + path, /*IsText=*/false, /*RequiresNullTerminator=*/false); + if (std::error_code ec = bufferOrError.getError()) { + llvm::errs() << "Error " << ec.message() << " reading from file " + << path << "\n"; + llvm_unreachable("llvm::MemoryBuffer::getFile failed"); + } + std::unique_ptr buffer = + std::move(bufferOrError.get()); + assert(buffer->getBufferIdentifier() == path && + "buffer identifier is file path"); + iter->second = std::move(buffer); + } + return iter->second; + } + + ExternalDataFileSlice readExternalData(const onnx::TensorProto &tp) { + assert(tp.has_data_location() && + tp.data_location() == onnx::TensorProto::EXTERNAL && + "tensor proto data must be external"); + // MemoryBuffer uses -1 to mean infinity + constexpr uint64_t infiniteLength = -1; + std::string location; + uint64_t offset = 0; + uint64_t length = infiniteLength; + for (const onnx::StringStringEntryProto &entry : tp.external_data()) { + assert(entry.has_key() && "external_data entry must have key"); + assert(entry.has_value() && "external_data entry must have value"); + if (entry.key() == "location") { + location = entry.value(); + } else if (entry.key() == "offset") { + offset = parseOffsetOrLength(entry.value()); + } else if (entry.key() == "length") { + length = parseOffsetOrLength(entry.value()); + } + } + assert(!location.empty() && "missing external data location"); + const std::shared_ptr &file = + mapExternalDataFile(location); + uint64_t fileLength = file->getBufferSize(); + assert(offset <= fileLength && "offset out of range"); + if (length != infiniteLength) + assert(offset + length <= fileLength && "length out of range"); + return {file, offset, length}; + } + Location UnknownLoc() const { return UnknownLoc::get(&context_); } Location ImportLoc(const onnx::NodeProto &node) { @@ -264,8 +333,14 @@ class FrontendGenImpl { } Value ImportTensor(const onnx::TensorProto &tensor) { - mlir::ElementsAttr mlirAttr = - onnxTensorProtoToElmAttr(&context_, options_.externalDataDir, tensor); + mlir::ElementsAttr mlirAttr; + if (tensor.has_data_location() && + tensor.data_location() == onnx::TensorProto::EXTERNAL) { + ExternalDataFileSlice fileSlice = readExternalData(tensor); + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor, &fileSlice); + } else { + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor); + } // Use the tensor name as Location. auto loc = NameLoc::get(builder_.getStringAttr("Initializer_" + tensor.name())); @@ -385,10 +460,16 @@ class FrontendGenImpl { mlirAttr = builder_.getI64ArrayAttr( llvm::ArrayRef(attr.ints().data(), attr.ints().size())); break; - case onnx::AttributeProto::TENSOR: - mlirAttr = onnxTensorProtoToElmAttr( - &context_, options_.externalDataDir, attr.t()); - break; + case onnx::AttributeProto::TENSOR: { + const onnx::TensorProto &tensor = attr.t(); + if (tensor.has_data_location() && + tensor.data_location() == onnx::TensorProto::EXTERNAL) { + ExternalDataFileSlice fileSlice = readExternalData(tensor); + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor, &fileSlice); + } else { + mlirAttr = onnxTensorProtoToElmAttr(&context_, tensor); + } + } break; case onnx::AttributeProto::STRINGS: { llvm::SmallVector vectorStringRef; for (const auto &item : attr.strings()) { @@ -1343,6 +1424,9 @@ class FrontendGenImpl { auto funcType = importGraph(graph, /*region=*/mainFunc.getBody(), /*op=*/mainFunc.getOperation(), /*useReturn=*/true); mainFunc.setType(funcType); + if (!externalDataFiles_.empty()) + mainFunc->setAttr("external_data_dir", + builder_.getStringAttr(options_.externalDataDir)); // Emit entry point op describing inference function signature. auto entryPoint = ONNXEntryPointOp::create(UnknownLoc(), mainFunc); diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp index ae5ffaa589..3c2f0c43d0 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.cpp @@ -14,9 +14,6 @@ #include "src/Dialect/ONNX/ElementsAttr/Strides.hpp" #include "src/Support/TypeUtilities.hpp" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/Endian.h" - #include #include @@ -58,16 +55,25 @@ void widenArray( /*static*/ DisposableElementsAttr DisposableElementsAttr::create(ShapedType type, size_t id, BType bufferBType, ArrayRef strides, - const Buffer &buffer, Transformer transformer) { + const Buffer &buffer, uint64_t offset, uint64_t length, + Transformer transformer) { + assert(offset <= buffer->getBufferSize() && "offset out of range"); + assert(length <= buffer->getBufferSize() - offset && "length out of range"); BType btype = btypeOfMlirType(type.getElementType()); assert((transformer != nullptr || wideBTypeOfBType(bufferBType) == wideBTypeOfBType(btype)) && "buffer wide type mismatch requires transformer"); bool isContiguous = areStridesContiguous(type.getShape(), strides); + int64_t numBufferElements = isContiguous + ? type.getNumElements() + : getStridedSize(type.getShape(), strides); + uint64_t numBytes = numBufferElements * bytewidthOfBType(bufferBType); + assert(numBytes == length && "length mismatch"); DisposableElementsAttr a = Base::get( type.getContext(), type, strides, bufferBType, btype, isContiguous, id); DisposableElementsAttributeStorage &s = *a.getImpl(); s.buffer = buffer; + s.offset = offset; s.transformer = std::move(transformer); return a; } @@ -78,7 +84,7 @@ void DisposableElementsAttr::dispose() { } bool DisposableElementsAttr::isSplat() const { - return areStridesSplat(getStrides()) && getBuffer()->getBufferSize() != 0; + return getNumBufferElements() == 1; } BType DisposableElementsAttr::getBType() const { return getImpl()->btype; } @@ -98,6 +104,8 @@ auto DisposableElementsAttr::getBuffer() const -> const Buffer & { return getImpl()->buffer; } +uint64_t DisposableElementsAttr::getOffset() const { return getImpl()->offset; } + auto DisposableElementsAttr::getTransformer() const -> const Transformer & { assert(!isDisposed()); return getImpl()->transformer; @@ -124,7 +132,7 @@ unsigned DisposableElementsAttr::getBufferElementBytewidth() const { } int64_t DisposableElementsAttr::getNumBufferElements() const { - return getBuffer()->getBufferSize() / getBufferElementBytewidth(); + return getStridedSize(getShape(), getStrides()); } ArrayBuffer DisposableElementsAttr::getWideNums() const { @@ -156,92 +164,6 @@ DenseElementsAttr DisposableElementsAttr::toDenseElementsAttr() const { return DenseElementsAttr::getFromRawBuffer(getType(), bytes.get()); } -namespace { -// Perform byte swap if system endianness is BE and elements are multi-byte. -bool shouldSwapLEBytes(unsigned elementByteWidth) { - return elementByteWidth > 1 && llvm::support::endian::system_endianness() != - llvm::support::endianness::little; -} -} // namespace - -/*static*/ -std::unique_ptr DisposableElementsAttr::parse( - AsmParser &parser, ShapedType type) { - size_t id = 0; // The parsed id is ignored. - std::string str; - if (parser.parseLess() || parser.parseInteger(id) || parser.parseColon() || - parser.parseString(&str)) - return nullptr; - StringRef hex = str; - std::string bytes; - if (!hex.consume_front("0x") || (hex.size() & 1) || - !llvm::tryGetFromHex(hex, bytes)) { - parser.emitError(parser.getCurrentLocation(), "ill-formed hex string"); - return nullptr; - } - if (bytes.size() != static_cast(getSizeInBytes(type))) { - parser.emitError( - parser.getCurrentLocation(), "data size doesn't match type size"); - return nullptr; - } - if (!shouldSwapLEBytes(getIntOrFloatByteWidth(type.getElementType()))) { - return llvm::MemoryBuffer::getMemBufferCopy(bytes); - } else { - // Reorder bytes from little-endian on big-endian platforms: - std::unique_ptr writeBuffer = - llvm::WritableMemoryBuffer::getNewUninitMemBuffer(bytes.size()); - DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( - {bytes.data(), bytes.size()}, writeBuffer->getBuffer(), type); - return writeBuffer; - } -} - -void DisposableElementsAttr::printWithoutType(AsmPrinter &printer) const { - // It would be ideal if we could read the printer flags from printer instead - // of constructing them here, because printer may have been constructed with - // an override of elideLargeElementsAttrs which we cannot see here. - // Oh well, at least OpPrintingFlags().shouldElideElementsAttr(ElementsAttr) - // lets us respect the --mlir-elide-elementsattrs-if-larger command line flag. - static OpPrintingFlags printerFlags{}; - printer << getMnemonic() << "<" << getImpl()->id << ":"; - if (!printerFlags.shouldElideElementsAttr(*this)) { - auto rawBytes = getRawBytes(); - SmallVector buffer; - ArrayRef bytes; - if (!shouldSwapLEBytes(getIntOrFloatByteWidth(getElementType()))) { - bytes = rawBytes.get(); - } else { - // Reorder raw bytes to little-endian on big-endian platforms: - buffer.resize_for_overwrite(rawBytes.get().size()); - DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( - rawBytes.get(), buffer, getType()); - ArrayRef bufferRef(buffer); - bytes = bufferRef; - } - printer << "\"0x" << llvm::toHex(castArrayRef(bytes)) << "\""; - } else { - printer << "__elided__"; - } - printer << ">"; -} - -void DisposableElementsAttr::printAsDenseElementsAttr( - AsmPrinter &printer) const { - static OpPrintingFlags printerFlags{}; - if (isSplat() || !printerFlags.shouldElideElementsAttr(*this)) { - // Take shortcut by first converting to DenseElementsAttr. - // NOTE: This creates a copy which is never garbage collected. This is not - // only slow but also defeats the garbage collection benefits of - // DisposableElementsAttr, depending on when the printing - // takes place (the print at the end of onnx-mlir-opt in lit tests is ok). - printer.printAttribute(toDenseElementsAttr()); - // TODO: Do the work to print without constructing DenseElementsAttr. - } else { - // In this special case it's easy to avoid conversion to DenseElementsAttr. - printer << "dense<__elided__> : " << getType(); - } -} - void DisposableElementsAttr::readBytesAsWideNums( ArrayRef srcBytes, llvm::MutableArrayRef dst) const { widenArray(getBufferBType(), srcBytes, dst); @@ -250,7 +172,8 @@ void DisposableElementsAttr::readBytesAsWideNums( } ArrayRef DisposableElementsAttr::getBufferBytes() const { - return asArrayRef(getBuffer()->getBuffer()); + size_t numBytes = getNumBufferElements() * getBufferElementBytewidth(); + return asArrayRef(getBuffer()->getBuffer().substr(getOffset(), numBytes)); } ArrayBuffer DisposableElementsAttr::getBufferAsWideNums() const { diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index 2f8912f577..ed3da86290 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -19,10 +19,8 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" -#include "llvm/ADT/StringRef.h" #include "llvm/Support/MemoryBuffer.h" #include @@ -118,7 +116,7 @@ class DisposableElementsAttr // created instance. static DisposableElementsAttr create(ShapedType type, size_t id, BType bufferBType, ArrayRef strides, const Buffer &buffer, - Transformer transformer); + uint64_t offset, uint64_t length, Transformer transformer); // Clears the buffer payload shared_ptr which decreases the reference count // and, if it reaches zero, frees or closes the underlying MemoryBuffer's @@ -154,6 +152,8 @@ class DisposableElementsAttr const Buffer &getBuffer() const; + uint64_t getOffset() const; + const Transformer &getTransformer() const; bool isContiguous() const; @@ -266,18 +266,6 @@ class DisposableElementsAttr // Makes deep copy. DenseElementsAttr toDenseElementsAttr() const; - static constexpr StringLiteral getMnemonic() { return {"dense_disposable"}; } - - // Returns the underlying data as a flat byte array in row-major order. - // If the element type is bool the data holds one byte (with value 0 or 1) per - // bool (contrary to how DenseElementsAttr::getRawData() bit packs bools). - static std::unique_ptr parse( - AsmParser &parser, ShapedType type); - - void printWithoutType(AsmPrinter &printer) const; - - void printAsDenseElementsAttr(AsmPrinter &printer) const; - private: // Widens and transforms bytes into WideNums in accordance with // bufferDType and transformer. diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp index 23f4c5f9d5..11fc23805b 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttributeStorage.hpp @@ -37,7 +37,7 @@ struct DisposableElementsAttributeStorage : public AttributeStorage { onnx_mlir::BType bufferBType, onnx_mlir::BType btype, bool isContiguous, size_t id) : type(type), strides(strides), bufferBType(bufferBType), btype(btype), - isContiguous(isContiguous), id(id) {} + isContiguous(isContiguous), id(id), offset(0) {} // Equality and hashKey are engineered to defeat the storage uniquer. // We don't want uniqueing because we can't compare transformers for equality @@ -103,6 +103,8 @@ struct DisposableElementsAttributeStorage : public AttributeStorage { // file closed) when no one points to it anymore. Buffer buffer; + uint64_t offset; + // Reads the buffer elements to WideNums corresponding to type's // element type. Is null if data is not transformed. // In this case the buffer data type and the type's element type must promote diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp index 9342b9b286..392bc1cced 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp @@ -24,12 +24,12 @@ DisposablePool::~DisposablePool() {} ElementsAttr DisposablePool::createElementsAttr(ShapedType type, BType bufferBType, ArrayRef strides, - const mlir::DisposableElementsAttr::Buffer &buffer, - DisposableElementsAttr::Transformer transformer) { + const mlir::DisposableElementsAttr::Buffer &buffer, uint64_t offset, + uint64_t length, DisposableElementsAttr::Transformer transformer) { static std::atomic counter{0}; size_t id = ++counter; - auto disposable = DisposableElementsAttr::create( - type, id, bufferBType, strides, buffer, std::move(transformer)); + auto disposable = DisposableElementsAttr::create(type, id, bufferBType, + strides, buffer, offset, length, std::move(transformer)); if (insert(disposable)) { return disposable; } else { diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp index 9a538c06ef..95533df639 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp @@ -57,8 +57,8 @@ class DisposablePool : public mlir::DialectInterface::Base { // otherwise returns conversion to DenseElementsAttr. mlir::ElementsAttr createElementsAttr(mlir::ShapedType type, BType bufferBType, llvm::ArrayRef strides, - const mlir::DisposableElementsAttr::Buffer &buffer, - mlir::DisposableElementsAttr::Transformer transformer); + const mlir::DisposableElementsAttr::Buffer &buffer, uint64_t offset, + uint64_t length, mlir::DisposableElementsAttr::Transformer transformer); // Disposes every DisposableElementsAttr in the pool which is unreachable // (doesn't appear in moduleOp). diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index e0e23f03df..0d38694914 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -58,6 +58,8 @@ struct ElementsAttrBuilder::ElementsProperties { BType bufferBType; SmallVector strides; std::shared_ptr buffer; + uint64_t offset; + uint64_t length; const Transformer &transformer; }; @@ -66,8 +68,17 @@ ElementsAttrBuilder::ElementsAttrBuilder(DisposablePool &disposablePool) ElementsAttr ElementsAttrBuilder::fromMemoryBuffer( ShapedType type, std::unique_ptr membuf) { + return fromMemoryBuffer(type, std::move(membuf), 0, membuf->getBufferSize()); +} + +ElementsAttr ElementsAttrBuilder::fromMemoryBuffer(ShapedType type, + std::shared_ptr membuf, uint64_t offset, + uint64_t length) { BType btype = btypeOfMlirType(type.getElementType()); - return createWithDefaultStrides(type, btype, std::move(membuf)); + uint64_t numBytes = type.getNumElements() * bytewidthOfBType(btype); + assert(numBytes == length && "length mismatch"); + return createWithDefaultStrides( + type, btype, std::move(membuf), offset, length); } DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( @@ -78,8 +89,9 @@ DisposableElementsAttr ElementsAttrBuilder::toDisposableElementsAttr( if (!disposablePool.isActive()) return nullptr; ElementsProperties props = getElementsProperties(dense); - ElementsAttr created = create(dense.getType(), props.bufferBType, - props.strides, props.buffer, props.transformer); + ElementsAttr created = + create(dense.getType(), props.bufferBType, props.strides, props.buffer, + props.offset, props.length, props.transformer); // Check for race condition where disposablePool became inactive since we // checked, in which case it returns a DenseElementsAttr which we don't // want. @@ -328,7 +340,7 @@ ElementsAttr ElementsAttrBuilder::castElementType( : composeTransforms(props.transformer, functionTransformer(wideCaster(oldWideType, newWideType))); return create(newType, props.bufferBType, props.strides, props.buffer, - std::move(transformer)); + props.offset, props.length, std::move(transformer)); } namespace { @@ -353,7 +365,7 @@ ElementsAttr ElementsAttrBuilder::transpose( ShapedType transposedType = type.clone(transposedShape); auto transposedStrides = transposeDims(props.strides, perm); return create(transposedType, props.bufferBType, transposedStrides, - props.buffer, props.transformer); + props.buffer, props.offset, props.length, props.transformer); } ElementsAttr ElementsAttrBuilder::reshape( @@ -369,7 +381,7 @@ ElementsAttr ElementsAttrBuilder::reshape( if (auto reshapedStrides = reshapeStrides(shape, props.strides, reshapedShape)) return create(reshapedType, props.bufferBType, *reshapedStrides, - props.buffer, props.transformer); + props.buffer, props.offset, props.length, props.transformer); auto disp = elms.dyn_cast(); assert(disp && "reshapeStrides() always succeeds for non-Disposable " @@ -399,7 +411,7 @@ ElementsAttr ElementsAttrBuilder::expand( ShapedType expandedType = type.clone(expandedShape); auto expandedStrides = expandStrides(props.strides, expandedShape); return create(expandedType, props.bufferBType, expandedStrides, props.buffer, - props.transformer); + props.offset, props.length, props.transformer); } namespace { @@ -833,9 +845,14 @@ auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) static Transformer nullTransformer = nullptr; if (auto disposable = elements.dyn_cast()) { ArrayRef strides = disposable.getStrides(); - return {/*.bufferBType=*/disposable.getBufferBType(), + BType bufferBType = disposable.getBufferBType(); + uint64_t length = + disposable.getNumBufferElements() * bytewidthOfBType(bufferBType); + return {/*.bufferBType=*/bufferBType, /*.strides=*/{strides.begin(), strides.end()}, /*.buffer=*/disposable.getBuffer(), + /*.offset=*/disposable.getOffset(), + /*.length=*/length, /*.transformer=*/disposable.getTransformer()}; } else if (auto dense = elements.dyn_cast()) { ShapedType type = dense.getType(); @@ -845,9 +862,13 @@ auto ElementsAttrBuilder::getElementsProperties(ElementsAttr elements) } else { strides = getDefaultStrides(type.getShape()); } + std::unique_ptr buffer = getMemoryBuffer(dense); + uint64_t length = buffer->getBufferSize(); return {/*.bufferBType=*/btypeOfMlirType(type.getElementType()), /*.strides=*/{strides.begin(), strides.end()}, - /*.buffer=*/getMemoryBuffer(dense), + /*.buffer=*/std::move(buffer), + /*.offset=*/0, + /*.length=*/length, /*.transformer=*/nullTransformer}; } // TODO: consider supporting more ElementsAttr types @@ -879,6 +900,7 @@ ElementsAttr ElementsAttrBuilder::doTransform( ElementsProperties props = getElementsProperties(elms); return create(transformedType, props.bufferBType, props.strides, props.buffer, + props.offset, props.length, composeTransforms(props.transformer, std::move(transformer))); } @@ -890,7 +912,7 @@ ElementsAttr ElementsAttrBuilder::expandAndTransform(ElementsAttr elms, expandStrides(props.strides, expandedTransformedType.getShape()); return create(expandedTransformedType, props.bufferBType, expandedStrides, - props.buffer, + props.buffer, props.offset, props.length, composeTransforms(props.transformer, std::move(transformer))); } @@ -900,21 +922,23 @@ ElementsAttr ElementsAttrBuilder::fromRawBytes( std::unique_ptr writeBuffer = llvm::WritableMemoryBuffer::getNewUninitMemBuffer(size); bytesFiller(writeBuffer->getBuffer()); - return createWithDefaultStrides(type, bufferBType, std::move(writeBuffer)); + return createWithDefaultStrides( + type, bufferBType, std::move(writeBuffer), 0, size); } ElementsAttr ElementsAttrBuilder::createWithDefaultStrides(ShapedType type, - BType bufferBType, std::unique_ptr membuf) { + BType bufferBType, std::shared_ptr membuf, + uint64_t offset, uint64_t length) { auto strides = getDefaultStrides(type.getShape()); - return create(type, bufferBType, strides, std::move(membuf)); + return create(type, bufferBType, strides, std::move(membuf), offset, length); } ElementsAttr ElementsAttrBuilder::create(ShapedType type, BType bufferBType, ArrayRef strides, - const std::shared_ptr &buffer, - Transformer transformer) { - return disposablePool.createElementsAttr( - type, bufferBType, strides, buffer, std::move(transformer)); + const std::shared_ptr &buffer, uint64_t offset, + uint64_t length, Transformer transformer) { + return disposablePool.createElementsAttr(type, bufferBType, strides, buffer, + offset, length, std::move(transformer)); } } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index 1dd7ea3dba..38666f9e1f 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -37,6 +37,10 @@ class ElementsAttrBuilder { mlir::ElementsAttr fromMemoryBuffer( mlir::ShapedType type, std::unique_ptr membuf); + mlir::ElementsAttr fromMemoryBuffer(mlir::ShapedType type, + std::shared_ptr membuf, uint64_t offset, + uint64_t length); + // Wraps elements in a DisposableElementsAttr if it isn't already a // DisposableElementsAttr, provided the underlying DisposablePool is active. // If elements is DenseElementsAttr the wrapper points into elements' raw @@ -228,13 +232,14 @@ class ElementsAttrBuilder { const Filler &bytesFiller); mlir::ElementsAttr createWithDefaultStrides(mlir::ShapedType type, - BType bufferBType, std::unique_ptr membuf); + BType bufferBType, std::shared_ptr membuf, + uint64_t offset, uint64_t length); // Create a DisposableElementsAttr and put it in disposablePool. mlir::ElementsAttr create(mlir::ShapedType type, BType bufferBType, llvm::ArrayRef strides, - const std::shared_ptr &buffer, - Transformer transformer = nullptr); + const std::shared_ptr &buffer, uint64_t offset, + uint64_t length, Transformer transformer = nullptr); DisposablePool &disposablePool; }; diff --git a/src/Dialect/ONNX/ElementsAttr/Strides.cpp b/src/Dialect/ONNX/ElementsAttr/Strides.cpp index 762da77ab2..d525ee9e65 100644 --- a/src/Dialect/ONNX/ElementsAttr/Strides.cpp +++ b/src/Dialect/ONNX/ElementsAttr/Strides.cpp @@ -28,6 +28,17 @@ uint64_t getStridesPosition( return pos; } +int64_t getStridedSize(ArrayRef shape, ArrayRef strides) { + assert(shape.size() == strides.size()); + int64_t lastPos = 0; + for (auto [dimSize, stride] : llvm::zip(shape, strides)) { + if (dimSize == 0) + return 0; + lastPos += (dimSize - 1) * stride; + } + return lastPos + 1; +} + bool areStridesContiguous(ArrayRef shape, ArrayRef strides) { unsigned rank = shape.size(); assert(rank == strides.size()); diff --git a/src/Dialect/ONNX/ElementsAttr/Strides.hpp b/src/Dialect/ONNX/ElementsAttr/Strides.hpp index 9ee4df7bd5..0e988b0eca 100644 --- a/src/Dialect/ONNX/ElementsAttr/Strides.hpp +++ b/src/Dialect/ONNX/ElementsAttr/Strides.hpp @@ -54,10 +54,11 @@ namespace onnx_mlir { uint64_t getStridesPosition( llvm::ArrayRef index, llvm::ArrayRef strides); -// The data is splat (singleton) if strides are all zero. -inline bool areStridesSplat(llvm::ArrayRef strides) { - return llvm::all_of(strides, [](int64_t s) { return s == 0; }); -} +// Size of the underlying linear array that represents shape with the given +// strides. Returns 0 if shape is empty. Returns 1 if strides are splat (all +// zeros) and shape is non-empty. +int64_t getStridedSize( + llvm::ArrayRef shape, llvm::ArrayRef strides); // Returns strides == getDefaultStrides(shape, strides). bool areStridesContiguous( diff --git a/src/Dialect/ONNX/ONNXAttributes.cpp b/src/Dialect/ONNX/ONNXAttributes.cpp index 48fdc4359e..55d0704196 100644 --- a/src/Dialect/ONNX/ONNXAttributes.cpp +++ b/src/Dialect/ONNX/ONNXAttributes.cpp @@ -22,7 +22,9 @@ #include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Endian.h" using namespace mlir; using namespace onnx_mlir; @@ -98,6 +100,136 @@ void ONNXTensorEncodingAttr::print(AsmPrinter &printer) const { printer << "}>"; } +//===----------------------------------------------------------------------===// +// ONNX Attribute: DisposablElementsAttr +//===----------------------------------------------------------------------===// + +namespace { +constexpr StringLiteral getDisposablElementsAttrMnemonic() { + return {"dense_disposable"}; +} + +#if 1 +Attribute parseDisposablElementsAttr(AsmParser &parser, Type type) { + llvm_unreachable("TODO: implement"); +} + +void printDisposablElementsAttr( + AsmPrinter &printer, DisposableElementsAttr disposable) { + llvm_unreachable("TODO: implement"); +} +#else +Attribute parseDisposablElementsAttr(AsmParser &parser, Type type) { + return DisposableElementsAttr::parse( + parser, type, [&](size_t id, ElementsAttr &elms) -> ParseResult { + return OnnxElementsAttrBuilder(type.getContext()) + .parseElements(parser, cast(type), id, elms); + }); +} + +void printDisposablElementsAttr( + AsmPrinter &printer, DisposableElementsAttr disposable) { + disposable.printWithoutType(printer); +} + +static Attribute parse(AsmParser &parser, Type type, + function_ref parseElements); + +void printWithoutType(AsmPrinter &printer) const; + +void printAsDenseElementsAttr(AsmPrinter &printer) const; + +namespace { +// Perform byte swap if system endianness is BE and elements are multi-byte. +bool shouldSwapLEBytes(unsigned elementByteWidth) { + return elementByteWidth > 1 && llvm::support::endian::system_endianness() != + llvm::support::endianness::little; +} +} // namespace + +/*static*/ +Attribute DisposableElementsAttr::parse(AsmParser &parser, Type type, + function_ref parseElements) { + size_t id = 0; // The parsed id. + ElementsAttr elms; + if (parser.parseLess() || parser.parseInteger(id) || parser.parseColon() || + parseElements(id, elms) || parser.parseGreater()) + return nullptr; + + return elms; +} + +void DisposableElementsAttr::printWithoutType(AsmPrinter &printer) const { + // It would be ideal if we could read the printer flags from printer instead + // of constructing them here, because printer may have been constructed with + // an override of elideLargeElementsAttrs which we cannot see here. + // Oh well, at least OpPrintingFlags().shouldElideElementsAttr(ElementsAttr) + // lets us respect the --mlir-elide-elementsattrs-if-larger command line flag. + static OpPrintingFlags printerFlags{}; + printer << getMnemonic() << "<" << getImpl()->id << ":"; + if (!printerFlags.shouldElideElementsAttr(*this)) { + auto rawBytes = getRawBytes(); + SmallVector buffer; + ArrayRef bytes; + if (!shouldSwapLEBytes(getIntOrFloatByteWidth(getElementType()))) { + bytes = rawBytes.get(); + } else { + // Reorder raw bytes to little-endian on big-endian platforms: + buffer.resize_for_overwrite(rawBytes.get().size()); + DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( + rawBytes.get(), buffer, getType()); + ArrayRef bufferRef(buffer); + bytes = bufferRef; + } + printer << "\"0x" << llvm::toHex(castArrayRef(bytes)) << "\""; + } else { + printer << "__elided__"; + } + printer << ">"; +} + +mlir::ParseResult parseElements(mlir::AsmParser &parser, mlir::ShapedType type, + size_t id, mlir::ElementsAttr &elms); + +ParseResult ElementsAttrBuilder::parseElements( + AsmParser &parser, ShapedType type, size_t id, ElementsAttr &elms) { + std::string str; + if (parser.parseString(&str)) + return failure(); + if (!parser.parseOptionalColon()) { + uint64_t offset = 0; + uint64_t length = 0; + if (parser.parseInteger(offset) || parser.parseColon() || + parser.parseInteger(length)) + return failure(); + return parser.emitError(parser.getCurrentLocation(), "TODO: implement"); + } else { + StringRef hex = str; + std::string bytes; + if (!hex.consume_front("0x") || (hex.size() & 1) || + !llvm::tryGetFromHex(hex, bytes)) + return parser.emitError( + parser.getCurrentLocation(), "ill-formed hex string"); + if (bytes.size() != static_cast(getSizeInBytes(type))) + return parser.emitError( + parser.getCurrentLocation(), "data size doesn't match type size"); + if (!shouldSwapLEBytes(getIntOrFloatByteWidth(type.getElementType()))) { + elms = + fromMemoryBuffer(type, llvm::MemoryBuffer::getMemBufferCopy(bytes)); + } else { + // Reorder bytes from little-endian on big-endian platforms: + std::unique_ptr writeBuffer = + llvm::WritableMemoryBuffer::getNewUninitMemBuffer(bytes.size()); + DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( + {bytes.data(), bytes.size()}, writeBuffer->getBuffer(), type); + elms = fromMemoryBuffer(type, std::move(writeBuffer)); + } + return success(); + } +} +#endif +} // namespace + //===----------------------------------------------------------------------===// // ONNX Attributes: TableGen generated implementation //===----------------------------------------------------------------------===// @@ -119,17 +251,12 @@ void ONNXDialect::registerAttributes() { Attribute ONNXDialect::parseAttribute( DialectAsmParser &parser, Type type) const { // generatedAttributeParser is generated in ONNXAttributes.cpp.inc + Attribute attr; StringRef attrTag; - if (Attribute attr; - generatedAttributeParser(parser, &attrTag, type, attr).has_value()) + if (generatedAttributeParser(parser, &attrTag, type, attr).has_value()) return attr; - if (attrTag == DisposableElementsAttr::getMnemonic()) { - auto shapedTy = type.cast(); - if (auto membuf = DisposableElementsAttr::parse(parser, shapedTy)) - return OnnxElementsAttrBuilder(type.getContext()) - .fromMemoryBuffer(shapedTy, std::move(membuf)); - else - return {}; + if (attrTag == getDisposablElementsAttrMnemonic()) { + return parseDisposablElementsAttr(parser, type); } parser.emitError(parser.getCurrentLocation()) << "unknown attribute `" << attrTag << "` in dialect `ONNX`"; @@ -142,6 +269,6 @@ void ONNXDialect::printAttribute( // generatedAttributePrinter is generated in ONNXAttributes.cpp.inc if (succeeded(generatedAttributePrinter(attr, printer))) return; - if (auto elements = attr.dyn_cast()) - elements.printWithoutType(printer); + if (auto disposable = attr.dyn_cast()) + printDisposablElementsAttr(printer, disposable); } diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 69f0e0076e..a9c9117e42 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -16,6 +16,7 @@ #include "src/Dialect/ONNX/DialectBuilder.hpp" #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" +#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp" #include "mlir/Dialect/Traits.h" #include "llvm/ADT/STLExtras.h" @@ -82,12 +83,33 @@ namespace { // Helpers adapted from corresponding methods in mlir/lib/AsmParser/Parser.cpp //===----------------------------------------------------------------------===// +void printAsDenseElementsAttr(AsmPrinter &printer, ElementsAttr elements) { + // It would be ideal if we could read the printer flags from printer instead + // of constructing them here, because printer may have been constructed with + // an override of elideLargeElementsAttrs which we cannot see here. + // Oh well, at least OpPrintingFlags().shouldElideElementsAttr(ElementsAttr) + // lets us respect the --mlir-elide-elementsattrs-if-larger command line flag. + static OpPrintingFlags printerFlags{}; + if (elements.isSplat() || !printerFlags.shouldElideElementsAttr(elements)) { + // Take shortcut by first converting to DenseElementsAttr. + // NOTE: This creates a copy which is never garbage collected. This is not + // only slow but also defeats the garbage collection benefits of + // DisposableElementsAttr, depending on when the printing + // takes place (the print at the end of onnx-mlir-opt in lit tests is ok). + printer.printAttribute(ElementsAttrBuilder::toDenseElementsAttr(elements)); + // TODO: Do the work to print without constructing DenseElementsAttr. + } else { + // In this special case it's easy to avoid conversion to DenseElementsAttr. + printer << "dense<__elided__> : " << elements.getType(); + } +} + // Print DisposableElementsAttr as a DenseElementsAttr, because // DisposableElementsAttr is an internal representation, so we hide it // in this way. void printAttribute(OpAsmPrinter &printer, Attribute attr) { if (auto disposable = attr.dyn_cast()) - disposable.printAsDenseElementsAttr(printer); + printAsDenseElementsAttr(printer, disposable); else printer.printAttribute(attr); } diff --git a/test/mlir/onnx/parse/external_data.json b/test/mlir/onnx/parse/external_data.json index 30f31bb3fa..5208091607 100644 --- a/test/mlir/onnx/parse/external_data.json +++ b/test/mlir/onnx/parse/external_data.json @@ -608,7 +608,7 @@ ] } // CHECK-LABEL: func.func @main_graph -// CHECK-SAME: () -> (tensor<3xf32>, tensor<3xui8>, tensor<3xi8>, tensor<3xui16>, tensor<3xi16>, tensor<3xi32>, tensor<3xi64>, tensor<3xi1>, tensor<3xf16>, tensor<3xf64>, tensor<3xui32>, tensor<3xui64>) attributes {input_names = [], output_names = ["output0", "output1", "output2", "output3", "output4", "output5", "output6", "output7", "output8", "output9", "output10", "output11"]} { +// CHECK-SAME: () -> (tensor<3xf32>, tensor<3xui8>, tensor<3xi8>, tensor<3xui16>, tensor<3xi16>, tensor<3xi32>, tensor<3xi64>, tensor<3xi1>, tensor<3xf16>, tensor<3xf64>, tensor<3xui32>, tensor<3xui64>) attributes {external_data_dir = "{{.*}}parse", input_names = [], output_names = ["output0", "output1", "output2", "output3", "output4", "output5", "output6", "output7", "output8", "output9", "output10", "output11"]} { // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 0.000000e+00, 1.000000e+00]> : tensor<3xf32> // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 0, 1]> : tensor<3xui8> // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[1, 0, 1]> : tensor<3xi8> diff --git a/utils/onnxExternalizeData.py b/utils/onnxExternalizeData.py index f801537044..398b7d4b2b 100755 --- a/utils/onnxExternalizeData.py +++ b/utils/onnxExternalizeData.py @@ -6,27 +6,54 @@ # Converts the data in tensors in an onnx model to external data. # Useful tool for constructing external data examples for testing. # +# Call with --make_raw to convert non-raw tensors to raw_data to make them +# eligible to become external data, otherwise +# onnx.save_model(model, path, save_as_external_data=True) +# doesn't convert them to external data. +# ################################################################################ import argparse -import os import onnx +import os +import sys parser = argparse.ArgumentParser() parser.add_argument('model_path', type=str, help="Path to the ONNX model") +parser.add_argument('--size_threshold', type=int, default=0, help="Only convert tensors with byte size no smaller than this") parser.add_argument('--no_all_tensors_to_one_file', action='store_true', help="Save tensors to multiple files") parser.add_argument('--no_convert_attribute', action='store_true', help="Only convert initializer tensors to external data") +parser.add_argument('--make_raw', action='store_true', help="Convert non-raw tensors to raw_data") args = parser.parse_args() +def get_all_tensors(onnx_model_proto): + return onnx.external_data_helper._get_all_tensors(onnx_model_proto) + +def get_initializer_tensors(onnx_model_proto): + return onnx.external_data_helper._get_initializer_tensors(onnx_model_proto) + def main(): filepath = args.model_path basename = os.path.basename(filepath) model = onnx.load_model(filepath) + if args.make_raw: + tensors = get_initializer_tensors(model) if args.no_convert_attribute else get_all_tensors(model) + for tensor in tensors: + if not tensor.HasField("raw_data") and tensor.data_type != onnx.TensorProto.STRING: + arr = onnx.numpy_helper.to_array(tensor) + bytes = arr.tobytes() + if sys.getsizeof(bytes) >= args.size_threshold: + storage_field = onnx.helper.tensor_dtype_to_field(tensor.data_type) + tensor.ClearField(storage_field) + tensor.raw_data = bytes + if sys.byteorder == "big": + # Convert endian from big to little + onnx.numpy_helper.convert_endian(tensor) onnx.save_model(model, args.model_path, save_as_external_data=True, all_tensors_to_one_file=not args.no_all_tensors_to_one_file, location=f"{basename}.ext", - size_threshold=0, + size_threshold=args.size_threshold, convert_attribute=not args.no_convert_attribute )