Skip to content

Add simple cli to de/compress files #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
from argparse import Namespace
import contextlib
import io
import os
import pathlib
import sys
import unittest
import tempfile

import zstandard as zstd
from zstandard import cli


@contextlib.contextmanager
def redirect_stdout():
sys.stdout = io.StringIO()
yield sys.stdout
sys.stdout = sys.__stdout__


class TestCli(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.count = 1
cls._tmp_dir = tempfile.TemporaryDirectory()
cls.tmp_dir = pathlib.Path(cls._tmp_dir.name)

def tearDown(self):
for file in self.tmp_dir.iterdir():
file.unlink()

@classmethod
def tearDownClass(cls):
cls._tmp_dir.__exit__(None, None, None)

def test_parser_default(self):
args = cli._parser(["my-file"])
self.assertEqual(
args,
Namespace(
file="my-file",
outfile=None,
decompress=False,
level=3,
override=False,
threads=0,
rm=False,
),
)

def test_parser(self):
args = cli._parser(
[
"my-file",
"-d",
"-o",
"out-file",
"-l",
"2",
"--override",
"--threads",
"-1",
"--rm",
]
)
self.assertEqual(
args,
Namespace(
file="my-file",
outfile="out-file",
decompress=True,
level=2,
override=True,
threads=-1,
rm=True,
),
)

def make_source_file(self):
data = os.urandom(2048) * 2
name = self.tmp_dir / "source.dat"
name.write_bytes(data)
return data, name

def compress(self, data, **kw):
out = io.BytesIO()
if kw:
ctx = zstd.ZstdCompressor(**kw)
else:
ctx = None
with zstd.open(out, "wb", cctx=ctx) as f:
f.write(data)
return out.getvalue()

def _compress_run(self, main_args, compress_args, outfile=None):
data, name = self.make_source_file()
with redirect_stdout():
cli.main([str(name.absolute())] + main_args)

if not outfile:
dest = self.tmp_dir / f"{name.name}.zst"
else:
dest = self.tmp_dir / outfile
self.assertTrue(dest.exists())
self.assertEqual(
dest.read_bytes(), self.compress(data, **compress_args)
)
return data, name

def test_compress(self):
self._compress_run([], {})

def test_compress_level(self):
self._compress_run(["-l", "7"], dict(level=7))

def test_compress_thread(self):
self._compress_run(["--threads", "-1"], {})

def test_compress_rm(self):
_, name = self._compress_run(["--rm"], {})
self.assertFalse(name.exists())

def test_compress_override(self):
_, name = self._compress_run([], {})
self.assertRaises(FileExistsError, lambda: self._compress_run([], {}))
self._compress_run(["--override"], {})

def test_compress_not_exist(self):
name = self.tmp_dir / "unknown_file"
self.assertRaises(
FileNotFoundError, lambda: cli.main([str(name.absolute())])
)

def test_compress_outfile(self):
self._compress_run(["-o", str(self.tmp_dir / "output")], {}, "output")

def test_compress_same_output(self):
_, name = self.make_source_file()

def go(dest):
self.assertRaises(
NotImplementedError,
lambda: cli.main([str(name.absolute()), "-o", str(dest)]),
)

go(name)
go(name.parent / ".." / name.parent.name / name.name)

def make_compressed_file(self):
data = os.urandom(2048) * 2
name = self.tmp_dir / "source.dat.zst"
with zstd.open(name, "wb") as f:
f.write(data)
return data, name

def _decompress_run(self, main_args, outfile=None):
data, name = self.make_compressed_file()
with redirect_stdout():
cli.main([str(name.absolute()), "-d"] + main_args)

if not outfile:
dest = self.tmp_dir / name.stem
else:
dest = self.tmp_dir / outfile
self.assertTrue(dest.exists())
self.assertEqual(dest.read_bytes(), data)
return data, name

def test_decompress(self):
self._decompress_run([])

def test_decompress_rm(self):
_, name = self._decompress_run(["--rm"])
self.assertFalse(name.exists())

def test_decompress_override(self):
_, name = self._decompress_run([])
self.assertRaises(FileExistsError, lambda: self._decompress_run([]))
self._decompress_run(["--override"])

def test_decompress_not_exist(self):
name = self.tmp_dir / "unknown_file"
self.assertRaises(
FileNotFoundError, lambda: cli.main([str(name.absolute()), "-d"])
)

def test_decompress_outfile(self):
self._decompress_run(["-o", str(self.tmp_dir / "output")], "output")

def test_decompress_same_output(self):
_, name = self.make_source_file()

def go(dest, n=name):
self.assertRaises(
NotImplementedError,
lambda: cli.main([str(n.absolute()), "-d", "-o", str(dest)]),
)

go(name)
go(name.parent / ".." / name.parent.name / name.name)

def test_decompress_no_ext(self):
no_ext = self.tmp_dir / "no_ext"
no_ext.touch()
self.assertRaises(
NotImplementedError,
lambda: cli.main([str(no_ext.absolute()), "-d"]),
)
4 changes: 4 additions & 0 deletions zstandard/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from zstandard.cli import main

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion zstandard/backend_cffi.py
Original file line number Diff line number Diff line change
@@ -1754,7 +1754,7 @@ class ZstdCompressor(object):

:param level:
Integer compression level. Valid values are all negative integers
through 22.
through 22. Default 3.
:param dict_data:
A ``ZstdCompressionDict`` to be used to compress with dictionary
data.
129 changes: 129 additions & 0 deletions zstandard/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import argparse

CHUNK_SIZE = 8 * 1024 * 1024


def main(args=None):
arguments = _parser(args)
_run(arguments)


def _parser(args=None):
parser = argparse.ArgumentParser(
prog="zstandard",
description="Simple cli to use zstandard to de/compress files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("file", help="A filename")
parser.add_argument(
"-o",
"--outfile",
help="Save to this filename",
)
parser.add_argument(
"-d",
"--decompress",
help="Decompress instead of compressing.",
action="store_true",
)
parser.add_argument(
"-l",
"--level",
help="Integer compression level. "
"Valid values are all negative integers through 22",
type=int,
default=3,
)
parser.add_argument(
"--override",
help="Allow overriding existing output files",
action="store_true",
)
parser.add_argument(
"--threads",
help="Number of threads to use to compress data concurrently. "
"0 disables multi-threaded compression. -1 means all logical CPUs",
type=int,
default=0,
)
parser.add_argument(
"--rm",
help="Remove source file after successful de/compression",
action="store_true",
)
return parser.parse_args(args)


def _check_args(args):
import pathlib

file = pathlib.Path(args.file)
if not file.exists() or not file.is_file():
raise FileNotFoundError(
f"File {args.file} does not exits or is not a file"
)
if args.outfile is None:
if args.decompress:
outfile = file.with_name(file.stem)
else:
outfile = file.with_name(f"{file.name}.zst")
else:
outfile = pathlib.Path(args.outfile)
if file.resolve() == outfile.resolve():
raise NotImplementedError(
"Overriding the input file is not supported."
"Please specify another output file"
)
if outfile.exists() and not args.override:
raise FileExistsError(
f"File {args.outfile} exists. Pass --override to override it"
)
return file, outfile


def _run(args):
import zstandard as zstd

file, outfile = _check_args(args)

in_fp, out_fp = None, None
try:
if args.decompress:
in_fp = zstd.open(file, "rb")
out_fp = open(outfile, "wb")
operation = "decompressing"
else:
in_fp = open(file, "rb")
ctx = zstd.ZstdCompressor(level=args.level, threads=args.threads)
out_fp = zstd.open(outfile, "wb", cctx=ctx)
operation = "compressing"
tot = 0
while True:
data = in_fp.read(CHUNK_SIZE)
if not data:
break
out_fp.write(data)
tot += len(data)
print(f"{operation} .. {tot//(1024*1024)} MB", end="\r")
print(" " * 100, end="\r")

finally:
if in_fp is not None:
in_fp.close()
if out_fp is not None:
out_fp.close()

outsize = outfile.stat().st_size
if args.decompress:
print(outfile.name, f": {outsize} bytes")
else:
insize = file.stat().st_size or 1
print(
args.file,
f": {outsize/insize*100:.2f}% ({insize} => {outsize} bytes)",
outfile.name,
)

if args.rm:
file.unlink()