diff --git a/numcodecs/zstd.pyx b/numcodecs/zstd.pyx index efd12fa2..f1a84011 100644 --- a/numcodecs/zstd.pyx +++ b/numcodecs/zstd.pyx @@ -8,12 +8,12 @@ from cpython.buffer cimport PyBUF_ANY_CONTIGUOUS, PyBUF_WRITEABLE from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING - from .compat_ext cimport Buffer from .compat_ext import Buffer from .compat import ensure_contiguous_ndarray from .abc import Codec +from libc.stdlib cimport malloc, realloc, free cdef extern from "zstd.h": @@ -22,6 +22,23 @@ cdef extern from "zstd.h": struct ZSTD_CCtx_s: pass ctypedef ZSTD_CCtx_s ZSTD_CCtx + + struct ZSTD_DStream_s: + pass + ctypedef ZSTD_DStream_s ZSTD_DStream + + struct ZSTD_inBuffer_s: + const void* src + size_t size + size_t pos + ctypedef ZSTD_inBuffer_s ZSTD_inBuffer + + struct ZSTD_outBuffer_s: + void* dst + size_t size + size_t pos + ctypedef ZSTD_outBuffer_s ZSTD_outBuffer + cdef enum ZSTD_cParameter: ZSTD_c_compressionLevel=100 ZSTD_c_checksumFlag=201 @@ -37,12 +54,20 @@ cdef extern from "zstd.h": size_t dstCapacity, const void* src, size_t srcSize) nogil - size_t ZSTD_decompress(void* dst, size_t dstCapacity, const void* src, size_t compressedSize) nogil + size_t ZSTD_decompressStream(ZSTD_DStream* zds, + ZSTD_outBuffer* output, + ZSTD_inBuffer* input) nogil + + size_t ZSTD_DStreamOutSize() nogil + ZSTD_DStream* ZSTD_createDStream() nogil + size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil + size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil + cdef long ZSTD_CONTENTSIZE_UNKNOWN cdef long ZSTD_CONTENTSIZE_ERROR unsigned long long ZSTD_getFrameContentSize(const void* src, @@ -56,7 +81,7 @@ cdef extern from "zstd.h": unsigned ZSTD_isError(size_t code) nogil - const char* ZSTD_getErrorName(size_t code) + const char* ZSTD_getErrorName(size_t code) nogil VERSION_NUMBER = ZSTD_versionNumber() @@ -156,7 +181,8 @@ def decompress(source, dest=None): source : bytes-like Compressed data. Can be any object supporting the buffer protocol. 
cdef stream_decompress(Buffer source_buffer):
    """Decompress Zstandard data whose frame header omits the content size.

    The output buffer starts at twice the size of the input (but never
    smaller than ``ZSTD_DStreamOutSize()``) and grows by
    ``ZSTD_DStreamOutSize()`` bytes whenever the decoder fills it.

    Parameters
    ----------
    source_buffer : Buffer
        Compressed data buffer.

    Returns
    -------
    dest : bytes
        Object containing decompressed data.

    Raises
    ------
    RuntimeError
        If a buffer or the decompression stream cannot be allocated, if zstd
        reports an error, if the output size would overflow ``size_t``, or if
        the input ends before the current frame is complete.

    """

    cdef:
        char *source_ptr
        void *dest_ptr = NULL
        void *new_dst
        size_t source_size, dest_size, new_size
        size_t DEST_GROWTH_SIZE, status
        const char *error
        ZSTD_inBuffer in_buf
        ZSTD_outBuffer out_buf
        ZSTD_DStream *zds = NULL

    # Recommended size for the output buffer, guaranteed to flush at least
    # one complete block in all circumstances
    DEST_GROWTH_SIZE = ZSTD_DStreamOutSize()

    source_ptr = source_buffer.ptr
    source_size = source_buffer.nbytes

    # unknown content size, guess it is twice the size of the source
    dest_size = source_size * 2
    if dest_size < source_size:
        # the doubling wrapped around size_t; fall back to the source size
        dest_size = source_size
    if dest_size < DEST_GROWTH_SIZE:
        # minimum dest_size is DEST_GROWTH_SIZE
        dest_size = DEST_GROWTH_SIZE

    try:
        # Both resources are owned by dest_ptr / zds for the whole function
        # and released exactly once in the finally block below.
        dest_ptr = malloc(dest_size)
        if dest_ptr == NULL:
            raise RuntimeError('Zstd stream decompression error on malloc: '
                               'could not allocate output buffer')

        zds = ZSTD_createDStream()
        if zds == NULL:
            raise RuntimeError('Zstd stream decompression error on '
                               'ZSTD_createDStream: could not allocate stream')

        with nogil:
            status = ZSTD_initDStream(zds)
        if ZSTD_isError(status):
            error = ZSTD_getErrorName(status)
            raise RuntimeError('Zstd stream decompression error on ZSTD_initDStream: %s' % error)

        in_buf.src = source_ptr
        in_buf.size = source_size
        in_buf.pos = 0

        out_buf.dst = dest_ptr
        out_buf.size = dest_size
        out_buf.pos = 0

        # Initialize to 1 to force a loop iteration
        status = 1
        while status > 0 or in_buf.pos < in_buf.size:
            # Possible returned values of ZSTD_decompressStream:
            # 0: frame is completely decoded and fully flushed
            # error code (checked with ZSTD_isError)
            # >0: more decoding or flushing is needed for the current frame
            with nogil:
                status = ZSTD_decompressStream(zds, &out_buf, &in_buf)

            if ZSTD_isError(status):
                error = ZSTD_getErrorName(status)
                raise RuntimeError('Zstd stream decompression error on ZSTD_decompressStream: %s' % error)

            if status == 0:
                # frame finished; the loop condition decides whether another
                # frame follows in the remaining input
                continue

            if out_buf.pos == out_buf.size:
                # There is more to decompress, grow the buffer
                new_size = out_buf.size + DEST_GROWTH_SIZE

                if new_size < out_buf.size or new_size < DEST_GROWTH_SIZE:
                    raise RuntimeError('Zstd stream decompression error: output buffer overflow')

                new_dst = realloc(dest_ptr, new_size)

                if new_dst == NULL:
                    # dest_ptr is still valid and is freed in the finally block
                    raise RuntimeError('Zstd stream decompression error on realloc: could not expand output buffer')

                dest_ptr = new_dst
                out_buf.dst = new_dst
                out_buf.size = new_size

            elif in_buf.pos == in_buf.size:
                # The decoder flushed everything it had (output not full) and
                # consumed all input, yet the frame is unfinished: the input
                # is truncated. Looping again could never make progress.
                raise RuntimeError('Zstd stream decompression error: incomplete input data')

        # Copy the output to a bytes object
        dest = PyBytes_FromStringAndSize(<char *>dest_ptr, out_buf.pos)

    finally:
        if zds != NULL:
            ZSTD_freeDStream(zds)
        # free(NULL) is a no-op, so this is safe on every path
        free(dest_ptr)

    return dest