diff --git a/src/bmaptool/BmapCopy.py b/src/bmaptool/BmapCopy.py
index f8f0a2f..368ad4e 100644
--- a/src/bmaptool/BmapCopy.py
+++ b/src/bmaptool/BmapCopy.py
@@ -213,16 +213,22 @@ class BmapCopy(object):
         instance.
         """
 
-    def __init__(self, image, dest, bmap=None, image_size=None):
+    def __init__(self, image, dest, bmap=None, image_size=None, checksum_retry=None):
         """
         The class constructor. The parameters are:
-            image      - file-like object of the image which should be copied,
-                         should only support 'read()' and 'seek()' methods,
-                         and only seeking forward has to be supported.
-            dest       - file object of the destination file to copy the image
-                         to.
-            bmap       - file object of the bmap file to use for copying.
-            image_size - size of the image in bytes.
+            image          - file-like object of the image which should be copied,
+                             should only support 'read()' and 'seek()' methods,
+                             and only seeking forward has to be supported.
+            dest           - file object of the destination file to copy the image
+                             to.
+            bmap           - file object of the bmap file to use for copying.
+            image_size     - size of the image in bytes.
+            checksum_retry - number of retries for verifying written bmap ranges.
+                             If set to a value > 0, each written mapped range
+                             described in the bmap will be read back and its checksum
+                             verified against the expected value from the bmap file.
+                             If a mismatch is detected for a range, the entire range
+                             will be rewritten up to checksum_retry times.
         """
 
         self._xml = None
@@ -270,6 +276,10 @@ def __init__(self, image, dest, bmap=None, image_size=None):
         self._cs_attrib_name = None
         self._bmap_cs_attrib_name = None
 
+        # Checksum retry configuration for post-write verification
+        self._checksum_retry = checksum_retry
+        self._warned_missing_checksum = False
+
         # Special quirk for /dev/null which does not support fsync()
         if (
             stat.S_ISCHR(st_data.st_mode)
@@ -615,10 +625,12 @@ def _get_batches(self, first, last):
     def _get_data(self, verify):
         """
         This is generator which reads the image file in '_batch_blocks' chunks
-        and yields ('type', 'start', 'end', 'buf) tuples, where:
+        and yields ('type', 'start', 'end', 'buf', 'range_first', 'range_last', 'range_chksum')
+        tuples, where:
           * 'start' is the starting block number of the batch;
           * 'end' is the last block of the batch;
-          * 'buf' a buffer containing the batch data.
+          * 'buf' a buffer containing the batch data;
+          * 'range_first', 'range_last', 'range_chksum' are the bmap range info for post-write verification.
         """
 
         _log.debug("the reader thread has started")
@@ -655,7 +667,7 @@ def _get_data(self, verify):
                     % (blocks, self._batch_queue.qsize())
                 )
 
-                self._batch_queue.put(("range", start, start + blocks - 1, buf))
+                self._batch_queue.put(("range", start, start + blocks - 1, buf, first, last, chksum))
 
                 if verify and chksum and hash_obj.hexdigest() != chksum:
                     raise Error(
@@ -673,6 +685,153 @@ def _get_data(self, verify):
 
         self._batch_queue.put(None)
 
+    def _verify_written_blocks(self, start, end, expected_chksum):
+        """
+        Verify the checksum of blocks that were written to the destination file.
+        Returns True if the checksum matches, False otherwise.
+
+        Args:
+            start - starting block number
+            end - ending block number
+            expected_chksum - expected checksum string
+
+        Returns:
+            True if checksum matches, False if mismatch
+        """
+        if not self._cs_type or not expected_chksum:
+            return True
+
+        try:
+            # Calculate how many bytes to verify
+            # For the last block range, we may have a partial block
+            byte_start = start * self.block_size
+            byte_end = (end + 1) * self.block_size
+
+            # Limit to actual image size for partial last block
+            if self.image_size and byte_end > self.image_size:
+                byte_end = self.image_size
+
+            bytes_to_verify = byte_end - byte_start
+
+            self._f_dest.seek(byte_start)
+            buf = self._f_dest.read(bytes_to_verify)
+
+            # Short read is a mismatch (destination doesn't have all expected data)
+            if len(buf) != bytes_to_verify:
+                _log.warning(
+                    "short read while verifying blocks %d-%d: expected %d bytes, got %d",
+                    start, end, bytes_to_verify, len(buf)
+                )
+                return False
+
+            hash_obj = hashlib.new(self._cs_type)
+            hash_obj.update(buf)
+
+            calculated = hash_obj.hexdigest()
+            return calculated == expected_chksum
+        except IOError as err:
+            _log.error(
+                "error while verifying blocks %d-%d of '%s': %s",
+                start, end, self._dest_path, err
+            )
+            return False
+
+    def _drop_cached_blocks(self, start, end):
+        """
+        Best-effort invalidation of destination page cache for block range.
+        This helps make checksum verification read data back from storage
+        rather than returning cached pages.
+        """
+
+        if not hasattr(os, "posix_fadvise") or not hasattr(
+            os, "POSIX_FADV_DONTNEED"
+        ):
+            return
+
+        offset = start * self.block_size
+        length = (end - start + 1) * self.block_size
+
+        # Limit length to actual image size for partial last block
+        if self.image_size and offset + length > self.image_size:
+            length = self.image_size - offset
+
+        try:
+            os.posix_fadvise(
+                self._f_dest.fileno(), offset, length, os.POSIX_FADV_DONTNEED
+            )
+        except OSError as err:
+            _log.debug(
+                "cannot drop page cache for blocks %d-%d of '%s': %s",
+                start,
+                end,
+                self._dest_path,
+                err,
+            )
+
+    def _verify_range_with_retry(self, range_first, range_last, range_chksum, range_buffers):
+        """
+        Verify a block range's checksum after writing. If verification fails,
+        retry writing the range up to self._checksum_retry times.
+
+        Args:
+            range_first - first block of the range
+            range_last - last block of the range
+            range_chksum - expected checksum for the range
+            range_buffers - dict of {(start, end): buf} for blocks in this range
+        """
+        if self._checksum_retry and not range_chksum and not self._warned_missing_checksum:
+            _log.warning(
+                "checksum-retry requested but bmap file does not contain checksums; "
+                "skipping verification for blocks %d-%d and beyond"
+                % (range_first, range_last)
+            )
+            self._warned_missing_checksum = True
+
+        if not range_chksum or not self._checksum_retry:
+            return
+
+        retry_count = 0
+        retry_limit = int(self._checksum_retry)
+
+        while True:
+            # Sync to disk before reading back for verification
+            # This ensures we verify data that has actually reached the physical
+            # disk, not just what's in the kernel's page cache
+            self.sync()
+            self._drop_cached_blocks(range_first, range_last)
+
+            # Verify the entire range
+            if self._verify_written_blocks(range_first, range_last, range_chksum):
+                _log.debug(
+                    "checksum verification passed for blocks %d-%d"
+                    % (range_first, range_last)
+                )
+                return
+
+            # Checksum mismatch - retry
+            retry_count += 1
+            if retry_count > retry_limit:
+                raise Error(
+                    "checksum verification failed for blocks %d-%d after "
+                    "%d retries" % (range_first, range_last, retry_limit)
+                )
+
+            _log.warning(
+                "checksum mismatch for blocks %d-%d, retrying "
+                "(attempt %d/%d)" % (range_first, range_last, retry_count, retry_limit)
+            )
+
+            # Re-write all blocks in this range
+            for (start, end), buf in range_buffers.items():
+                try:
+                    self._f_dest.seek(start * self.block_size)
+                    self._f_dest.write(buf)
+                except IOError as err:
+                    raise Error(
+                        "error while writing blocks %d-%d of '%s': %s"
+                        % (start, end, self._dest_path, err)
+                    )
+
     def copy(self, sync=True, verify=True):
         """
         Copy the image to the destination file using bmap. The 'sync' argument
@@ -704,10 +863,19 @@ def copy(self, sync=True, verify=True):
 
         # Read the image in '_batch_blocks' chunks and write them to the
         # destination file
+        range_buffers = {}  # Track buffers for each range to enable retry
+        current_range = None
+
         while True:
             batch = self._batch_queue.get()
             if batch is None:
                 # No more data, the image is written
+                # Verify any remaining range
+                if self._checksum_retry and current_range:
+                    range_first, range_last, range_chksum = current_range
+                    self._verify_range_with_retry(
+                        range_first, range_last, range_chksum, range_buffers
+                    )
                 break
             elif batch[0] == "error":
                 # The reader thread encountered an error and passed us the
@@ -715,11 +883,24 @@ def copy(self, sync=True, verify=True):
                 exc_info = batch[1]
                 raise exc_info[1]
 
-            (start, end, buf) = batch[1:4]
+            (start, end, buf, range_first, range_last, range_chksum) = batch[1:7]
 
             assert len(buf) <= (end - start + 1) * self.block_size
             assert len(buf) > (end - start) * self.block_size
 
+            # Check if we've moved to a new range - if so, verify the previous one
+            new_range = (range_first, range_last, range_chksum)
+            if self._checksum_retry and current_range and new_range != current_range:
+                # Verify the completed range before moving to the next one
+                prev_first, prev_last, prev_chksum = current_range
+                self._verify_range_with_retry(
+                    prev_first, prev_last, prev_chksum, range_buffers
+                )
+                # Reset for new range
+                range_buffers = {}
+
+            current_range = new_range
+
             self._f_dest.seek(start * self.block_size)
 
             # Synchronize the destination file if we reached the watermark
@@ -740,6 +921,10 @@ def copy(self, sync=True, verify=True):
             blocks_written += end - start + 1
             bytes_written += len(buf)
 
+            # Track buffers for this range (for retry if needed)
+            if self._checksum_retry:
+                range_buffers[(start, end)] = buf
+
             self._update_progress(blocks_written)
 
         if not self.image_size:
@@ -785,7 +970,12 @@ def sync(self):
 
         if self._dest_supports_fsync:
             try:
-                os.fsync(self._f_dest.fileno()),
+                self._f_dest.flush()
+            except IOError as err:
+                raise Error("cannot flush '%s': %s" % (self._dest_path, err))
+
+            try:
+                os.fsync(self._f_dest.fileno())
             except OSError as err:
                 raise Error(
                     "cannot synchronize '%s': %s " % (self._dest_path, err.strerror)
@@ -800,14 +990,14 @@ class BmapBdevCopy(BmapCopy):
     scheduler.
     """
 
-    def __init__(self, image, dest, bmap=None, image_size=None):
+    def __init__(self, image, dest, bmap=None, image_size=None, checksum_retry=None):
         """
         The same as the constructor of the 'BmapCopy' base class, but adds
         useful guard-checks specific to block devices.
         """
 
         # Call the base class constructor first
-        BmapCopy.__init__(self, image, dest, bmap, image_size)
+        BmapCopy.__init__(self, image, dest, bmap, image_size, checksum_retry)
 
         self._dest_fsync_watermark = (6 * 1024 * 1024) // self.block_size
diff --git a/src/bmaptool/CLI.py b/src/bmaptool/CLI.py
index 40acf31..01b2dca 100644
--- a/src/bmaptool/CLI.py
+++ b/src/bmaptool/CLI.py
@@ -104,7 +104,7 @@ def __getattr__(self, name):
         return getattr(self._file_obj, name)
 
 
-def open_block_device(path):
+def open_block_device(path, need_read_access=False):
     """
     This is a helper function for 'open_files()' which is called if the
     destination file of the "copy" command is a block device. We handle block
@@ -115,17 +115,25 @@ def open_block_device(path):
     that we are the only users of the block device.
 
     This function opens a block device specified by 'path' in exclusive mode.
+    The 'need_read_access' parameter controls whether the device is opened
+    with read+write (True) or write-only (False) permissions.
     Returns opened file object.
     """
 
     try:
-        descriptor = os.open(path, os.O_WRONLY | os.O_EXCL)
+        if need_read_access:
+            descriptor = os.open(path, os.O_RDWR | os.O_EXCL)
+        else:
+            descriptor = os.open(path, os.O_WRONLY | os.O_EXCL)
     except OSError as err:
         error_out("cannot open block device '%s' in exclusive mode: %s", path, err)
 
     # Turn the block device file descriptor into a file object
     try:
-        file_obj = os.fdopen(descriptor, "wb")
+        if need_read_access:
+            file_obj = os.fdopen(descriptor, "r+b")
+        else:
+            file_obj = os.fdopen(descriptor, "wb")
     except OSError as err:
         os.close(descriptor)
         error_out("cannot open block device '%s':\n%s", path, err)
@@ -586,7 +594,9 @@ def open_files(args):
     try:
         if pathlib.Path(args.dest).is_block_device():
             dest_is_blkdev = True
-            dest_obj = open_block_device(args.dest)
+            # Request read access only if checksum-retry is enabled
+            need_read = bool(args.checksum_retry)
+            dest_obj = open_block_device(args.dest, need_read_access=need_read)
         else:
             dest_obj = open(args.dest, "wb+")
     except IOError as err:
@@ -610,6 +620,22 @@ def copy_command(args):
     if args.no_sig_verify and args.fingerprint:
         error_out("--no-sig-verify and --fingerprint cannot be used together")
 
+    # Validate checksum_retry argument
+    checksum_retry = None
+    if args.checksum_retry is not None:
+        try:
+            checksum_retry = int(args.checksum_retry)
+            if checksum_retry < 1:
+                error_out("--checksum-retry argument must be a positive integer")
+        except ValueError:
+            error_out("--checksum-retry argument must be a valid integer")
+
+    if checksum_retry:
+        log.info(
+            "checksum verification of written blocks enabled with up to %d retry attempts"
+            % checksum_retry
+        )
+
     image_obj, dest_obj, bmap_obj, bmap_path, image_size, dest_is_blkdev = open_files(
         args
     )
@@ -631,10 +657,14 @@ def copy_command(args):
         if dest_is_blkdev:
             dest_str = "block device '%s'" % args.dest
             # For block devices, use the specialized class
-            writer = BmapCopy.BmapBdevCopy(image_obj, dest_obj, bmap_obj, image_size)
+            writer = BmapCopy.BmapBdevCopy(
+                image_obj, dest_obj, bmap_obj, image_size, checksum_retry
+            )
         else:
             dest_str = "file '%s'" % os.path.basename(args.dest)
-            writer = BmapCopy.BmapCopy(image_obj, dest_obj, bmap_obj, image_size)
+            writer = BmapCopy.BmapCopy(
+                image_obj, dest_obj, bmap_obj, image_size, checksum_retry
+            )
     except BmapCopy.Error as err:
         error_out(err)
@@ -855,6 +885,19 @@ def parse_arguments():
     text = "do not verify the data checksum while writing"
     parser_copy.add_argument("--no-verify", action="store_true", help=text)
 
+    # The --checksum-retry option
+    text = (
+        "verify checksums of written blocks and retry writing on mismatch "
+        "(optional: number of retries, default 1)"
+    )
+    parser_copy.add_argument(
+        "--checksum-retry",
+        nargs="?",
+        const="1",
+        type=str,
+        help=text,
+    )
+
     # The --psplash-pipe option
     text = "write progress to a psplash pipe"
     parser_copy.add_argument("--psplash-pipe", help=text)
diff --git a/tests/test_checksum_retry.py b/tests/test_checksum_retry.py
new file mode 100644
index 0000000..c51775a
--- /dev/null
+++ b/tests/test_checksum_retry.py
@@ -0,0 +1,506 @@
+# -*- coding: utf-8 -*-
+# vim: ts=4 sw=4 tw=88 et ai si
+#
+# License: GPLv2
+#
+# This program is free software; you can redistribute it and/or modify it under
+# the terms of the GNU General Public License, version 2 or any later version,
+# as published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# General Public License for more details.
+
+"""
+This test verifies the checksum-retry functionality, which retries writing
+blocks if their checksums don't match after being written.
+"""
+
+# Disable the following pylint recommendations:
+#   * Too many public methods (R0904)
+#   * Too many local variables (R0914)
+# pylint: disable=R0904
+# pylint: disable=R0914
+
+import os
+import tempfile
+from bmaptool import BmapCreate, BmapCopy, TransRead, BmapHelpers
+
+# This is a work-around for Centos 6
+try:
+    import unittest2 as unittest  # pylint: disable=F0401
+except ImportError:
+    import unittest
+
+
+class ChecksumMismatchFile:
+    """
+    A wrapper around a file object that can simulate write failures
+    by corrupting specific byte ranges on read-back.
+
+    Modes:
+        'once': Corrupt on first read, then return clean data
+        'always': Corrupt on every read
+    """
+
+    def __init__(self, file_obj, corruption_ranges=None, mode='once'):
+        """
+        Initialize with a file object and optional list of byte ranges to corrupt.
+
+        Args:
+            file_obj: underlying file object
+            corruption_ranges: list of (start_byte, end_byte) tuples to corrupt on read
+            mode: 'once' to corrupt only on first read, 'always' to always corrupt
+        """
+        self._file = file_obj
+        self.corruption_ranges = corruption_ranges or []
+        self.mode = mode
+        self._corruption_count = 0  # Track how many times we've applied corruption
+
+    def write(self, data):
+        """Pass through write to underlying file."""
+        return self._file.write(data)
+
+    def read(self, size=-1):
+        """Read from file, optionally corrupting specified ranges."""
+        pos = self._file.tell()
+        data = bytearray(self._file.read(size))
+
+        should_corrupt = False
+        if self.mode == 'once':
+            should_corrupt = self._corruption_count == 0
+        elif self.mode == 'always':
+            should_corrupt = True
+
+        if should_corrupt and self.corruption_ranges:
+            # Apply corruptions to matching ranges
+            for start, end in self.corruption_ranges:
+                # Check if any part of this read overlaps with corruption range
+                read_end = pos + len(data)
+                if start < read_end and end > pos:
+                    # Calculate overlap
+                    overlap_start = max(0, start - pos)
+                    overlap_end = min(len(data), end - pos)
+                    # Flip bits in the overlapping region
+                    for i in range(overlap_start, overlap_end):
+                        data[i] ^= 0xFF
+            self._corruption_count += 1
+
+        return bytes(data)
+
+    def seek(self, pos, whence=0):
+        """Pass through seek to underlying file."""
+        return self._file.seek(pos, whence)
+
+    def tell(self):
+        """Pass through tell to underlying file."""
+        return self._file.tell()
+
+    def flush(self):
+        """Pass through flush to underlying file."""
+        return self._file.flush()
+
+    def fileno(self):
+        """Pass through fileno to underlying file."""
+        return self._file.fileno()
+
+    @property
+    def name(self):
+        """Pass through name from underlying file."""
+        return self._file.name
+
+    def close(self):
+        """Pass through close to underlying file."""
+        return self._file.close()
+
+
+def _create_test_image_with_bmap(image_size, directory=None):
+    """
+    Create a test image and its corresponding bmap file with checksums.
+    Returns (image_path, bmap_path)
+
+    The bmap file includes checksums for all ranges, which are verified
+    during copy operations, so no separate whole-image checksum is needed.
+    """
+    # Create image file with some data
+    f_image = tempfile.NamedTemporaryFile(
+        "wb+", prefix="test_image_", delete=False, dir=directory, suffix=".img"
+    )
+    image_path = f_image.name
+
+    # Fill image with deterministic data
+    chunk_size = 1024
+    written = 0
+    seed_byte = 0
+    while written < image_size:
+        to_write = min(chunk_size, image_size - written)
+        # Create deterministic pattern
+        data = bytes([(seed_byte + i) % 256 for i in range(to_write)])
+        f_image.write(data)
+        written += to_write
+        seed_byte = (seed_byte + to_write) % 256
+
+    f_image.flush()
+    f_image.close()
+
+    # Create bmap for the image
+    f_bmap = tempfile.NamedTemporaryFile(
+        "w+", prefix="test_image_", delete=False, dir=directory, suffix=".bmap"
+    )
+    bmap_path = f_bmap.name
+    f_bmap.close()
+
+    creator = BmapCreate.BmapCreate(image_path, bmap_path)
+    creator.generate(include_checksums=True)
+
+    return image_path, bmap_path
+
+
+class TestChecksumRetry(unittest.TestCase):
+    """Test the checksum-retry functionality."""
+
+    def test_checksum_retry_aligned_image(self):
+        """Test checksum-retry with image size aligned to block size."""
+        # 8192 bytes = 2 blocks of 4096 bytes
+        image_size = 8192
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        try:
+            # Test with checksum-retry enabled
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=1
+            )
+            # copy() verifies each bmap range's checksum during operation
+            writer.copy(sync=True, verify=True)
+
+            # Verify destination has correct size
+            dest_size = os.path.getsize(dest_path)
+            self.assertEqual(dest_size, image_size)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+    def test_checksum_retry_non_aligned_image(self):
+        """Test checksum-retry with image size NOT aligned to block size."""
+        # 9000 bytes = 2.19... blocks of 4096 bytes (non-aligned)
+        image_size = 9000
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        try:
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=1
+            )
+            # copy() verifies each bmap range's checksum during operation
+            writer.copy(sync=True, verify=True)
+
+            # Verify destination has correct size
+            dest_size = os.path.getsize(dest_path)
+            self.assertEqual(dest_size, image_size)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+    def test_checksum_retry_disabled_vs_enabled(self):
+        """Test that checksum_retry=None (disabled) vs checksum_retry=1 (enabled) produce identical results."""
+        image_size = 8192
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        dest_path_disabled = None
+        dest_path_enabled = None
+
+        try:
+            # Test with checksum_retry=None (disabled - no verification performed)
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path_disabled = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=None
+            )
+            writer.copy(sync=True, verify=True)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+
+            # Verify destination has correct size
+            size_disabled = os.path.getsize(dest_path_disabled)
+            self.assertEqual(size_disabled, image_size)
+
+            # Test with checksum_retry=1 (enabled with 1 retry attempt)
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path_enabled = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=1
+            )
+            writer.copy(sync=True, verify=True)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+
+            # Verify destination has correct size
+            size_enabled = os.path.getsize(dest_path_enabled)
+            self.assertEqual(size_enabled, image_size)
+
+            # Compare contents of both destination files - should be identical
+            with open(dest_path_disabled, "rb") as f:
+                content_disabled = f.read()
+            with open(dest_path_enabled, "rb") as f:
+                content_enabled = f.read()
+
+            self.assertEqual(content_disabled, content_enabled,
+                "Disabled and enabled checksum_retry modes should produce identical output")
+
+        finally:
+            if dest_path_disabled:
+                os.unlink(dest_path_disabled)
+            if dest_path_enabled:
+                os.unlink(dest_path_enabled)
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+    def test_checksum_retry_multiple_retries(self):
+        """Test checksum-retry with multiple retry attempts."""
+        image_size = 8192
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        try:
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            # Test with retry limit of 3
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=3
+            )
+            # copy() verifies each bmap range's checksum during operation
+            writer.copy(sync=True, verify=True)
+
+            # Verify destination has correct size
+            dest_size = os.path.getsize(dest_path)
+            self.assertEqual(dest_size, image_size)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+    def test_checksum_retry_without_bmap_checksums(self):
+        """Test checksum-retry when bmap file has no checksums."""
+        image_size = 8192
+        image_path = None
+        bmap_path = None
+
+        try:
+            # Create image
+            f_image = tempfile.NamedTemporaryFile(
+                "wb+", prefix="test_image_", delete=False, suffix=".img"
+            )
+            image_path = f_image.name
+            f_image.write(b"X" * image_size)
+            f_image.flush()
+            f_image.close()
+
+            # Create bmap WITHOUT checksums
+            f_bmap = tempfile.NamedTemporaryFile(
+                "w+", prefix="test_image_", delete=False, suffix=".bmap"
+            )
+            bmap_path = f_bmap.name
+            f_bmap.close()
+
+            creator = BmapCreate.BmapCreate(image_path, bmap_path)
+            creator.generate(include_checksums=False)  # No checksums!
+
+            # Copy with checksum-retry enabled - should issue warning but not fail
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=1
+            )
+            # Should complete without error (warning is only logged)
+            writer.copy(sync=True, verify=True)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            if image_path:
+                os.unlink(image_path)
+            if bmap_path:
+                os.unlink(bmap_path)
+
+    def test_checksum_retry_zero_value(self):
+        """Test that checksum_retry=0 or checksum_retry=False is treated as disabled."""
+        image_size = 8192
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        try:
+            # Test with retry=0 (should be treated as disabled)
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=0
+            )
+            writer.copy(sync=True, verify=True)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+    def test_checksum_retry_very_small_image(self):
+        """Test checksum-retry with very small image (1 byte, non-aligned)."""
+        image_size = 1
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        try:
+            f_image = TransRead.TransRead(image_path)
+            f_dest = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest.name
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=1
+            )
+            writer.copy(sync=True, verify=True)
+
+            # Verify the copy matches original
+            with open(dest_path, "rb") as f:
+                copy_data = f.read()
+
+            self.assertEqual(len(copy_data), image_size)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+    def test_checksum_mismatch_triggers_rewrite(self):
+        """
+        Test that a checksum mismatch on first verify triggers a rewrite,
+        which succeeds on the retry (after the wrapped file stops corrupting).
+        """
+        image_size = 8192
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        try:
+            f_image = TransRead.TransRead(image_path)
+            f_dest_real = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest_real.name
+            f_dest_real.close()
+
+            # Open destination and wrap it to corrupt data on first read
+            f_dest_unwrapped = open(dest_path, "r+b")
+            # Corrupt the first block (bytes 0-4095) on first read
+            f_dest = ChecksumMismatchFile(
+                f_dest_unwrapped, corruption_ranges=[(0, 4096)], mode="once"
+            )
+
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=2
+            )
+
+            # This should succeed: write will fail verification due to corruption,
+            # trigger a rewrite, and then succeed on retry (corruption only applies once)
+            writer.copy(sync=True, verify=True)
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+    def test_checksum_mismatch_exhausts_retries(self):
+        """
+        Test that when checksum mismatches persist beyond retry limit,
+        an Error is raised.
+        """
+        image_size = 8192
+        image_path, bmap_path = _create_test_image_with_bmap(image_size)
+
+        try:
+            f_image = TransRead.TransRead(image_path)
+            f_dest_real = tempfile.NamedTemporaryFile("w+b", delete=False)
+            dest_path = f_dest_real.name
+            f_dest_real.close()
+
+            # Open destination and wrap it to corrupt data on EVERY read
+            f_dest_unwrapped = open(dest_path, "r+b")
+            # Corrupt the first block on every read (mode='always')
+            f_dest = ChecksumMismatchFile(
+                f_dest_unwrapped,
+                corruption_ranges=[(0, 4096)],
+                mode="always"  # Always corrupt!
+            )
+
+            f_bmap = open(bmap_path, "r")
+
+            writer = BmapCopy.BmapCopy(
+                f_image, f_dest, f_bmap, image_size, checksum_retry=2
+            )
+
+            # This should raise Error because verification always fails
+            with self.assertRaises(BmapCopy.Error) as context:
+                writer.copy(sync=True, verify=True)
+
+            # Verify the error message mentions checksum verification failure
+            self.assertIn("checksum verification failed", str(context.exception))
+
+            f_image.close()
+            f_dest.close()
+            f_bmap.close()
+            os.unlink(dest_path)
+        finally:
+            os.unlink(image_path)
+            os.unlink(bmap_path)
+
+
+if __name__ == "__main__":
+    unittest.main()