diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..17ba414d --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,208 @@ +from argparse import Namespace +import contextlib +import io +import os +import pathlib +import sys +import unittest +import tempfile + +import zstandard as zstd +from zstandard import cli + + +@contextlib.contextmanager +def redirect_stdout(): + sys.stdout = io.StringIO() + yield sys.stdout + sys.stdout = sys.__stdout__ + + +class TestCli(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.count = 1 + cls._tmp_dir = tempfile.TemporaryDirectory() + cls.tmp_dir = pathlib.Path(cls._tmp_dir.name) + + def tearDown(self): + for file in self.tmp_dir.iterdir(): + file.unlink() + + @classmethod + def tearDownClass(cls): + cls._tmp_dir.__exit__(None, None, None) + + def test_parser_default(self): + args = cli._parser(["my-file"]) + self.assertEqual( + args, + Namespace( + file="my-file", + outfile=None, + decompress=False, + level=3, + override=False, + threads=0, + rm=False, + ), + ) + + def test_parser(self): + args = cli._parser( + [ + "my-file", + "-d", + "-o", + "out-file", + "-l", + "2", + "--override", + "--threads", + "-1", + "--rm", + ] + ) + self.assertEqual( + args, + Namespace( + file="my-file", + outfile="out-file", + decompress=True, + level=2, + override=True, + threads=-1, + rm=True, + ), + ) + + def make_source_file(self): + data = os.urandom(2048) * 2 + name = self.tmp_dir / "source.dat" + name.write_bytes(data) + return data, name + + def compress(self, data, **kw): + out = io.BytesIO() + if kw: + ctx = zstd.ZstdCompressor(**kw) + else: + ctx = None + with zstd.open(out, "wb", cctx=ctx) as f: + f.write(data) + return out.getvalue() + + def _compress_run(self, main_args, compress_args, outfile=None): + data, name = self.make_source_file() + with redirect_stdout(): + cli.main([str(name.absolute())] + main_args) + + if not outfile: + dest = self.tmp_dir / f"{name.name}.zst" + else: + dest = self.tmp_dir / outfile + self.assertTrue(dest.exists()) + self.assertEqual( + dest.read_bytes(), self.compress(data, **compress_args) + ) + return data, name + + def test_compress(self): + self._compress_run([], {}) + + def test_compress_level(self): + self._compress_run(["-l", "7"], dict(level=7)) + + def test_compress_thread(self): + self._compress_run(["--threads", "-1"], {}) + + def test_compress_rm(self): + _, name = self._compress_run(["--rm"], {}) + self.assertFalse(name.exists()) + + def test_compress_override(self): + _, name = self._compress_run([], {}) + self.assertRaises(FileExistsError, lambda: self._compress_run([], {})) + self._compress_run(["--override"], {}) + + def test_compress_not_exist(self): + name = self.tmp_dir / "unknown_file" + self.assertRaises( + FileNotFoundError, lambda: cli.main([str(name.absolute())]) + ) + + def test_compress_outfile(self): + self._compress_run(["-o", str(self.tmp_dir / "output")], {}, "output") + + def test_compress_same_output(self): + _, name = self.make_source_file() + + def go(dest): + self.assertRaises( + NotImplementedError, + lambda: cli.main([str(name.absolute()), "-o", str(dest)]), + ) + + go(name) + go(name.parent / ".." / name.parent.name / name.name) + + def make_compressed_file(self): + data = os.urandom(2048) * 2 + name = self.tmp_dir / "source.dat.zst" + with zstd.open(name, "wb") as f: + f.write(data) + return data, name + + def _decompress_run(self, main_args, outfile=None): + data, name = self.make_compressed_file() + with redirect_stdout(): + cli.main([str(name.absolute()), "-d"] + main_args) + + if not outfile: + dest = self.tmp_dir / name.stem + else: + dest = self.tmp_dir / outfile + self.assertTrue(dest.exists()) + self.assertEqual(dest.read_bytes(), data) + return data, name + + def test_decompress(self): + self._decompress_run([]) + + def test_decompress_rm(self): + _, name = self._decompress_run(["--rm"]) + self.assertFalse(name.exists()) + + def test_decompress_override(self): + _, name = self._decompress_run([]) + self.assertRaises(FileExistsError, lambda: self._decompress_run([])) + self._decompress_run(["--override"]) + + def test_decompress_not_exist(self): + name = self.tmp_dir / "unknown_file" + self.assertRaises( + FileNotFoundError, lambda: cli.main([str(name.absolute()), "-d"]) + ) + + def test_decompress_outfile(self): + self._decompress_run(["-o", str(self.tmp_dir / "output")], "output") + + def test_decompress_same_output(self): + _, name = self.make_source_file() + + def go(dest, n=name): + self.assertRaises( + NotImplementedError, + lambda: cli.main([str(n.absolute()), "-d", "-o", str(dest)]), + ) + + go(name) + go(name.parent / ".." / name.parent.name / name.name) + + def test_decompress_no_ext(self): + no_ext = self.tmp_dir / "no_ext" + no_ext.touch() + self.assertRaises( + NotImplementedError, + lambda: cli.main([str(no_ext.absolute()), "-d"]), + ) diff --git a/zstandard/__main__.py b/zstandard/__main__.py new file mode 100644 index 00000000..f4db9174 --- /dev/null +++ b/zstandard/__main__.py @@ -0,0 +1,4 @@ +from zstandard.cli import main + +if __name__ == "__main__": + main() diff --git a/zstandard/backend_cffi.py b/zstandard/backend_cffi.py index 43f645f1..61f27e97 100644 --- a/zstandard/backend_cffi.py +++ b/zstandard/backend_cffi.py @@ -1754,7 +1754,7 @@ class ZstdCompressor(object): :param level: Integer compression level. Valid values are all negative integers - through 22. + through 22. Default 3. :param dict_data: A ``ZstdCompressionDict`` to be used to compress with dictionary data. diff --git a/zstandard/cli.py b/zstandard/cli.py new file mode 100644 index 00000000..bb5d6e95 --- /dev/null +++ b/zstandard/cli.py @@ -0,0 +1,129 @@ +import argparse + +CHUNK_SIZE = 8 * 1024 * 1024 + + +def main(args=None): + arguments = _parser(args) + _run(arguments) + + +def _parser(args=None): + parser = argparse.ArgumentParser( + prog="zstandard", + description="Simple cli to use zstandard to de/compress files", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("file", help="A filename") + parser.add_argument( + "-o", + "--outfile", + help="Save to this filename", + ) + parser.add_argument( + "-d", + "--decompress", + help="Decompress instead of compressing.", + action="store_true", + ) + parser.add_argument( + "-l", + "--level", + help="Integer compression level. " + "Valid values are all negative integers through 22", + type=int, + default=3, + ) + parser.add_argument( + "--override", + help="Allow overriding existing output files", + action="store_true", + ) + parser.add_argument( + "--threads", + help="Number of threads to use to compress data concurrently. " + "0 disables multi-threaded compression. -1 means all logical CPUs", + type=int, + default=0, + ) + parser.add_argument( + "--rm", + help="Remove source file after successful de/compression", + action="store_true", + ) + return parser.parse_args(args) + + +def _check_args(args): + import pathlib + + file = pathlib.Path(args.file) + if not file.exists() or not file.is_file(): + raise FileNotFoundError( + f"File {args.file} does not exits or is not a file" + ) + if args.outfile is None: + if args.decompress: + outfile = file.with_name(file.stem) + else: + outfile = file.with_name(f"{file.name}.zst") + else: + outfile = pathlib.Path(args.outfile) + if file.resolve() == outfile.resolve(): + raise NotImplementedError( + "Overriding the input file is not supported." + "Please specify another output file" + ) + if outfile.exists() and not args.override: + raise FileExistsError( + f"File {args.outfile} exists. Pass --override to override it" + ) + return file, outfile + + +def _run(args): + import zstandard as zstd + + file, outfile = _check_args(args) + + in_fp, out_fp = None, None + try: + if args.decompress: + in_fp = zstd.open(file, "rb") + out_fp = open(outfile, "wb") + operation = "decompressing" + else: + in_fp = open(file, "rb") + ctx = zstd.ZstdCompressor(level=args.level, threads=args.threads) + out_fp = zstd.open(outfile, "wb", cctx=ctx) + operation = "compressing" + tot = 0 + while True: + data = in_fp.read(CHUNK_SIZE) + if not data: + break + out_fp.write(data) + tot += len(data) + print(f"{operation} .. {tot//(1024*1024)} MB", end="\r") + print(" " * 100, end="\r") + + finally: + if in_fp is not None: + in_fp.close() + if out_fp is not None: + out_fp.close() + + outsize = outfile.stat().st_size + if args.decompress: + print(outfile.name, f": {outsize} bytes") + else: + insize = file.stat().st_size or 1 + print( + args.file, + f": {outsize/insize*100:.2f}% ({insize} => {outsize} bytes)", + outfile.name, + ) + + if args.rm: + file.unlink()