diff --git a/docs/compression/zstd.rst b/docs/compression/zstd.rst
index 611b0e83..51f9628d 100644
--- a/docs/compression/zstd.rst
+++ b/docs/compression/zstd.rst
@@ -7,6 +7,9 @@ Zstd
     .. autoattribute:: codec_id
     .. automethod:: encode
     .. automethod:: decode
+        .. note::
+            If the compressed data does not contain the decompressed size, streaming
+            decompression will be used.
     .. automethod:: get_config
     .. automethod:: from_config
 
diff --git a/docs/release.rst b/docs/release.rst
index 30e02a57..ac4f851d 100644
--- a/docs/release.rst
+++ b/docs/release.rst
@@ -95,6 +95,8 @@ Maintenance
 
 Improvements
 ~~~~~~~~~~~~
+* Add streaming decompression for ZSTD (:issue:`699`)
+  By :user:`Mark Kittisopikul <mkitti>`.
 * Raise a custom `UnknownCodecError` when trying to retrieve an unavailable codec.
   By :user:`Cas Wognum <cwognum>`.
 
diff --git a/numcodecs/tests/test_pyzstd.py b/numcodecs/tests/test_pyzstd.py
new file mode 100644
index 00000000..6567fb40
--- /dev/null
+++ b/numcodecs/tests/test_pyzstd.py
@@ -0,0 +1,60 @@
+# Check Zstd against pyzstd package
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+
+try:
+    from numcodecs.zstd import Zstd
+except ImportError:  # pragma: no cover
+    pytest.skip("numcodecs.zstd not available", allow_module_level=True)
+
+if TYPE_CHECKING:  # pragma: no cover
+    import pyzstd
+else:
+    pyzstd = pytest.importorskip("pyzstd")
+
+test_data = [
+    b"Hello World!",
+    np.arange(113).tobytes(),
+    np.arange(10, 15).tobytes(),
+    np.random.randint(3, 50, size=(53,), dtype=np.uint16).tobytes(),
+]
+
+
+@pytest.mark.parametrize("input", test_data)
+def test_pyzstd_simple(input):
+    z = Zstd()
+    assert z.decode(pyzstd.compress(input)) == input
+    assert pyzstd.decompress(z.encode(input)) == input
+
+
+@pytest.mark.xfail
+@pytest.mark.parametrize("input", test_data)
+def test_pyzstd_simple_multiple_frames_decode(input):
+    z = Zstd()
+    assert z.decode(pyzstd.compress(input) * 2) == input * 2
+
+
+@pytest.mark.parametrize("input", test_data)
+def test_pyzstd_simple_multiple_frames_encode(input):
+    z = Zstd()
+    assert pyzstd.decompress(z.encode(input) * 2) == input * 2
+
+
+@pytest.mark.parametrize("input", test_data)
+def test_pyzstd_streaming(input):
+    pyzstd_c = pyzstd.ZstdCompressor()
+    pyzstd_d = pyzstd.ZstdDecompressor()
+    z = Zstd()
+
+    d_bytes = input
+    pyzstd_c.compress(d_bytes)
+    c_bytes = pyzstd_c.flush()
+    assert z.decode(c_bytes) == d_bytes
+    assert pyzstd_d.decompress(z.encode(d_bytes)) == d_bytes
+
+    # Test multiple streaming frames
+    assert z.decode(c_bytes * 2) == d_bytes * 2
+    assert z.decode(c_bytes * 3) == d_bytes * 3
+    assert z.decode(c_bytes * 99) == d_bytes * 99
diff --git a/numcodecs/tests/test_zstd.py b/numcodecs/tests/test_zstd.py
index de42d9e1..73433d71 100644
--- a/numcodecs/tests/test_zstd.py
+++ b/numcodecs/tests/test_zstd.py
@@ -1,4 +1,5 @@
 import itertools
+import subprocess
 
 import numpy as np
 import pytest
@@ -90,3 +91,79 @@ def test_native_functions():
     assert Zstd.default_level() == 3
     assert Zstd.min_level() == -131072
     assert Zstd.max_level() == 22
+
+
+def test_streaming_decompression():
+    # Test input frames with unknown frame content size
+    codec = Zstd()
+
+    # If the zstd command line interface is available, check the bytes
+    cli = zstd_cli_available()
+    if cli:
+        view_zstd_streaming_bytes()
+
+    # Encode bytes directly that were the result of streaming compression
+    bytes_val = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
+    dec = codec.decode(bytes_val)
+    dec_expected = b'Hello World!'
+    assert dec == dec_expected
+    if cli:
+        assert bytes_val == generate_zstd_streaming_bytes(dec_expected)
+        assert dec_expected == generate_zstd_streaming_bytes(bytes_val, decompress=True)
+
+    # Two consecutive frames given as input
+    bytes2 = bytes(bytearray(bytes_val * 2))
+    dec2 = codec.decode(bytes2)
+    dec2_expected = b'Hello World!Hello World!'
+    assert dec2 == dec2_expected
+    if cli:
+        assert dec2_expected == generate_zstd_streaming_bytes(bytes2, decompress=True)
+
+    # Single long frame that decompresses to a large output
+    bytes3 = b'(\xb5/\xfd\x00X$\x02\x00\xa4\x03ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz\x01\x00:\xfc\xdfs\x05\x05L\x00\x00\x08s\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08k\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08c\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08[\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08S\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08K\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08C\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08u\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08m\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08e\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08]\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08U\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08M\x01\x00\xfc\xff9\x10\x02M\x00\x00\x08E\x01\x00\xfc\x7f\x1d\x08\x01'
+    dec3 = codec.decode(bytes3)
+    dec3_expected = b'ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz' * 1024 * 32
+    assert dec3 == dec3_expected
+    if cli:
+        assert bytes3 == generate_zstd_streaming_bytes(dec3_expected)
+        assert dec3_expected == generate_zstd_streaming_bytes(bytes3, decompress=True)
+
+    # Garbage input results in an error
+    bytes4 = bytes(bytearray([0, 0, 0, 0, 0, 0, 0, 0]))
+    with pytest.raises(RuntimeError, match='Zstd decompression error: invalid input data'):
+        codec.decode(bytes4)
+
+
+def generate_zstd_streaming_bytes(input: bytes, *, decompress: bool = False) -> bytes:
+    """
+    Use the zstd command line interface to compress or decompress bytes in streaming mode.
+    """
+    if decompress:
+        args = ["-d"]
+    else:
+        args = []
+
+    p = subprocess.run(["zstd", "--no-check", *args], input=input, capture_output=True)
+    return p.stdout
+
+
+def view_zstd_streaming_bytes():
+    bytes_val = generate_zstd_streaming_bytes(b"Hello World!")
+    print(f" bytes_val = {bytes_val}")
+
+    bytes3 = generate_zstd_streaming_bytes(
+        b"ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz" * 1024 * 32
+    )
+    print(f" bytes3 = {bytes3}")
+
+
+def zstd_cli_available() -> bool:
+    # A missing executable makes subprocess.run raise OSError rather than
+    # return a non-zero exit status, so treat that as "not available".
+    try:
+        result = subprocess.run(
+            ["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+        )
+    except OSError:
+        return False
+    return result.returncode == 0
diff --git a/numcodecs/zstd.pyx b/numcodecs/zstd.pyx
index 82f2844a..f93da633 100644
--- a/numcodecs/zstd.pyx
+++ b/numcodecs/zstd.pyx
@@ -13,6 +13,7 @@ from .compat_ext cimport PyBytes_RESIZE, ensure_continguous_memoryview
 from .compat import ensure_contiguous_ndarray
 from .abc import Codec
 
+from libc.stdlib cimport malloc, realloc, free
 
 cdef extern from "zstd.h":
 
@@ -21,6 +22,23 @@ cdef extern from "zstd.h":
     struct ZSTD_CCtx_s:
         pass
     ctypedef ZSTD_CCtx_s ZSTD_CCtx
+
+    struct ZSTD_DStream_s:
+        pass
+    ctypedef ZSTD_DStream_s ZSTD_DStream
+
+    struct ZSTD_inBuffer_s:
+        const void* src
+        size_t size
+        size_t pos
+    ctypedef ZSTD_inBuffer_s ZSTD_inBuffer
+
+    struct ZSTD_outBuffer_s:
+        void* dst
+        size_t size
+        size_t pos
+    ctypedef ZSTD_outBuffer_s ZSTD_outBuffer
+
     cdef enum ZSTD_cParameter:
         ZSTD_c_compressionLevel=100
         ZSTD_c_checksumFlag=201
@@ -36,12 +54,20 @@ cdef extern from "zstd.h":
                           size_t dstCapacity,
                           const void* src,
                           size_t srcSize) nogil
-
     size_t ZSTD_decompress(void* dst,
                            size_t dstCapacity,
                            const void* src,
                            size_t compressedSize) nogil
 
+    size_t ZSTD_decompressStream(ZSTD_DStream* zds,
+                                 ZSTD_outBuffer* output,
+                                 ZSTD_inBuffer* input) nogil
+
+    size_t ZSTD_DStreamOutSize() nogil
+    ZSTD_DStream* ZSTD_createDStream() nogil
+    size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
+    size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil
+
     cdef long ZSTD_CONTENTSIZE_UNKNOWN
     cdef long ZSTD_CONTENTSIZE_ERROR
     unsigned long long ZSTD_getFrameContentSize(const void* src,
@@ -55,7 +81,7 @@ cdef extern from "zstd.h":
 
     unsigned ZSTD_isError(size_t code) nogil
 
-    const char* ZSTD_getErrorName(size_t code)
+    const char* ZSTD_getErrorName(size_t code) nogil
 
 
 VERSION_NUMBER = ZSTD_versionNumber()
@@ -157,7 +183,10 @@ def decompress(source, dest=None):
     source : bytes-like
         Compressed data. Can be any object supporting the buffer protocol.
     dest : array-like, optional
-        Object to decompress into.
+        Object to decompress into. If the content size is unknown, the
+        length of dest must match the decompressed size. If the content size
+        is unknown and dest is not provided, streaming decompression will be
+        used.
 
     Returns
     -------
@@ -174,6 +203,7 @@ def decompress(source, dest=None):
         char* dest_ptr
         size_t source_size, dest_size, decompressed_size
         size_t nbytes, cbytes, blocksize
+        size_t dest_nbytes
 
     # obtain source memoryview
     source_mv = ensure_continguous_memoryview(source)
@@ -187,9 +217,12 @@ def decompress(source, dest=None):
 
         # determine uncompressed size
         dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
-        if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_UNKNOWN or dest_size == ZSTD_CONTENTSIZE_ERROR:
+        if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
             raise RuntimeError('Zstd decompression error: invalid input data')
 
+        if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
+            return stream_decompress(source_pb)
+
         # setup destination buffer
         if dest is None:
             # allocate memory
@@ -203,6 +236,9 @@ def decompress(source, dest=None):
             dest_ptr = <char*>dest_pb.buf
             dest_nbytes = dest_pb.len
 
+        if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
+            dest_size = dest_nbytes
+
         # validate output buffer
         if dest_nbytes < dest_size:
             raise ValueError('destination buffer too small; expected at least %s, '
@@ -225,6 +261,111 @@ def decompress(source, dest=None):
 
     return dest
 
+cdef stream_decompress(const Py_buffer* source_pb):
+    """Decompress data of unknown size
+
+    Parameters
+    ----------
+    source_pb : Py_buffer
+        Compressed data buffer
+
+    Returns
+    -------
+    dest : bytes
+        Object containing decompressed data.
+
+    Raises
+    ------
+    RuntimeError
+        If zstd reports an error or a buffer cannot be allocated.
+    """
+
+    cdef:
+        const char *source_ptr
+        const char *error
+        void *dest_ptr
+        void *new_dst
+        size_t source_size, dest_size, new_size
+        size_t DEST_GROWTH_SIZE, status
+        ZSTD_inBuffer input
+        ZSTD_outBuffer output
+        ZSTD_DStream *zds
+
+    # Recommended size for output buffer, guaranteed to flush at least
+    # one complete block in all circumstances
+    DEST_GROWTH_SIZE = ZSTD_DStreamOutSize()
+
+    source_ptr = <const char*>source_pb.buf
+    source_size = source_pb.len
+
+    # unknown content size, guess it is twice the size as the source
+    dest_size = source_size * 2
+
+    if dest_size < DEST_GROWTH_SIZE:
+        # minimum dest_size is DEST_GROWTH_SIZE
+        dest_size = DEST_GROWTH_SIZE
+
+    dest_ptr = malloc(dest_size)
+    if dest_ptr == NULL:
+        raise RuntimeError('Zstd stream decompression error on malloc: could not allocate output buffer')
+
+    zds = ZSTD_createDStream()
+    if zds == NULL:
+        free(dest_ptr)
+        raise RuntimeError('Zstd stream decompression error on ZSTD_createDStream: could not create stream')
+
+    # Fill in the buffer structs before the try block so that the finally
+    # clause always frees a valid pointer, even if ZSTD_initDStream fails
+    input = ZSTD_inBuffer(source_ptr, source_size, 0)
+    output = ZSTD_outBuffer(dest_ptr, dest_size, 0)
+
+    try:
+
+        with nogil:
+
+            status = ZSTD_initDStream(zds)
+            if ZSTD_isError(status):
+                # zds is released once, in the finally block
+                error = ZSTD_getErrorName(status)
+                raise RuntimeError('Zstd stream decompression error on ZSTD_initDStream: %s' % error)
+
+            # Initialize to 1 to force a loop iteration
+            status = 1
+            while status > 0 or input.pos < input.size:
+                # Possible returned values of ZSTD_decompressStream:
+                # 0: frame is completely decoded and fully flushed
+                # error (<0)
+                # >0: suggested next input size
+                status = ZSTD_decompressStream(zds, &output, &input)
+
+                if ZSTD_isError(status):
+                    error = ZSTD_getErrorName(status)
+                    raise RuntimeError('Zstd stream decompression error on ZSTD_decompressStream: %s' % error)
+
+                # There is more to decompress, grow the buffer
+                if status > 0 and output.pos == output.size:
+                    new_size = output.size + DEST_GROWTH_SIZE
+
+                    if new_size < output.size or new_size < DEST_GROWTH_SIZE:
+                        raise RuntimeError('Zstd stream decompression error: output buffer overflow')
+
+                    new_dst = realloc(output.dst, new_size)
+
+                    if new_dst == NULL:
+                        # output.dst freed in finally block
+                        raise RuntimeError('Zstd stream decompression error on realloc: could not expand output buffer')
+
+                    output.dst = new_dst
+                    output.size = new_size
+
+        # Copy the output to a bytes object
+        dest = PyBytes_FromStringAndSize(<char *>output.dst, output.pos)
+
+    finally:
+        ZSTD_freeDStream(zds)
+        free(output.dst)
+
+    return dest
 
 class Zstd(Codec):
     """Codec providing compression using Zstandard.
diff --git a/pyproject.toml b/pyproject.toml
index e7d8ff69..387603f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,6 +57,7 @@ test = [
     "coverage",
     "pytest",
     "pytest-cov",
+    "pyzstd",
 ]
 test_extras = [
     "importlib_metadata",