Skip to content

Commit

Permalink
perf: faster cropping; parallel imread/crop; mt_loop and mp_loop help…
Browse files Browse the repository at this point in the history
…er functions (#43)
  • Loading branch information
yxlao authored Mar 14, 2024
1 parent 7deebf8 commit 7506d36
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 16 deletions.
65 changes: 60 additions & 5 deletions camtools/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def crop_white_boarders(im: np.array, padding: Tuple[int] = (0, 0, 0, 0)) -> np.
return im_dst


def compute_cropping(im: np.array) -> Tuple[int]:
def compute_cropping_v1(im: np.array) -> Tuple[int]:
"""
Compute top, bottom, left, right white boarder in pixels.
Expand Down Expand Up @@ -76,6 +76,61 @@ def compute_cropping(im: np.array) -> Tuple[int]:
return crop_t, crop_b, crop_l, crop_r


def compute_cropping(
im: np.ndarray,
check_with_v1=False,
) -> Tuple[int, int, int, int]:
"""
Compute top, bottom, left, right white borders in pixels for a 3-channel
image.
This function is designed for (H, W, 3) images, where each pixel's value
ranges from 0.0 to 1.0, and white pixels are represented by (1.0, 1.0, 1.0).
Args:
im (np.ndarray): Input image as a NumPy array of shape (H, W, 3) and
dtype float32.
Returns:
Tuple[int, int, int, int]: A tuple containing the number of white pixels
to crop from the top, bottom, left, and right edges, respectively.
"""
if not im.dtype == np.float32:
raise ValueError(f"Expected im.dtype to be np.float32, but got {im.dtype}")
if im.ndim != 3 or im.shape[2] != 3:
raise ValueError(f"Expected im to be of shape (H, W, 3), but got {im.shape}")

# Create a mask where white pixels are marked as True
white_mask = np.all(im == 1.0, axis=-1)

# Find the indices of rows and columns where there's at least one non-white pixel
rows_with_color = np.where(~white_mask.all(axis=1))[0]
cols_with_color = np.where(~white_mask.all(axis=0))[0]

# Determine the crop values based on the positions of non-white pixels
crop_t = rows_with_color[0] if len(rows_with_color) else 0
crop_b = im.shape[0] - rows_with_color[-1] - 1 if len(rows_with_color) else 0
crop_l = cols_with_color[0] if len(cols_with_color) else 0
crop_r = im.shape[1] - cols_with_color[-1] - 1 if len(cols_with_color) else 0

# Check the results against compute_cropping_v1 if requested
if check_with_v1:
crop_t_v1, crop_b_v1, crop_l_v1, crop_r_v1 = compute_cropping_v1(im)
if (
crop_t != crop_t_v1
or crop_b != crop_b_v1
or crop_l != crop_l_v1
or crop_r != crop_r_v1
):
raise ValueError(
f"compute_cropping_v1 failed to compute the correct cropping: "
f"({crop_t}, {crop_b}, {crop_l}, {crop_r}) != "
f"({crop_t_v1}, {crop_b_v1}, {crop_l_v1}, {crop_r_v1})"
)

return crop_t, crop_b, crop_l, crop_r


def apply_cropping_padding(
im_src: np.ndarray,
cropping: Tuple[int],
Expand Down Expand Up @@ -114,12 +169,12 @@ def apply_cropping_padding(
return im_dst


def apply_croppings_paddings(im_srcs, croppings, paddings):
def apply_croppings_paddings(src_ims, croppings, paddings):
"""
Apply cropping and padding to a list of images.
Args:
im_srcs: list of (H, W, 3) images, float32.
src_ims: list of (H, W, 3) images, float32.
croppings: list of 4-tuples
[
(crop_t, crop_b, crop_l, crop_r),
Expand All @@ -133,7 +188,7 @@ def apply_croppings_paddings(im_srcs, croppings, paddings):
...
]
"""
num_ims = len(im_srcs)
num_ims = len(src_ims)
if not len(croppings) == num_ims:
raise ValueError(f"len(croppings) == {len(croppings)} != {num_ims}")
if not len(paddings) == num_ims:
Expand All @@ -143,7 +198,7 @@ def apply_croppings_paddings(im_srcs, croppings, paddings):
raise ValueError(f"len(cropping) == {len(cropping)} != 4")

dst_ims = []
for im_src, cropping, padding in zip(im_srcs, croppings, paddings):
for im_src, cropping, padding in zip(src_ims, croppings, paddings):
im_dst = apply_cropping_padding(im_src, cropping, padding)
dst_ims.append(im_dst)

Expand Down
21 changes: 10 additions & 11 deletions camtools/tools/crop_boarders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
```
"""

from pathlib import Path
import argparse
from pathlib import Path

import numpy as np
import camtools as ct
from tqdm import tqdm

import camtools as ct


def instantiate_parser(parser):
parser.add_argument(
Expand Down Expand Up @@ -106,10 +108,8 @@ def entry_point(parser, args):
]

# Read.
src_ims = [
ct.io.imread(src_path, alpha_mode="white")
for src_path in tqdm(src_paths, desc="Reading images")
]
src_ims = ct.utility.mt_loop(ct.io.imread, src_paths, alpha_mode="white")

for src_im in src_ims:
if not src_im.dtype == np.float32:
raise ValueError(f"Input image {src_path} must be of dtype float32.")
Expand All @@ -126,10 +126,9 @@ def entry_point(parser, args):
"All images must be of the same shape when --same_crop is " "specified."
)

individual_croppings = [
ct.image.compute_cropping(im)
for im in tqdm(src_ims, desc="Computing croppings")
]
individual_croppings = ct.utility.mt_loop(ct.image.compute_cropping, src_ims)

# Compute the minimum cropping boarders.
min_crop_u, min_crop_d, min_crop_l, min_crop_r = individual_croppings[0]
for crop_u, crop_d, crop_l, crop_r in individual_croppings[1:]:
min_crop_u = min(min_crop_u, crop_u)
Expand All @@ -147,7 +146,7 @@ def entry_point(parser, args):
paddings = [(padding, padding, padding, padding)] * len(src_ims)
else:
# Compute cropping boarders.
croppings = [ct.image.compute_cropping(src_im) for src_im in src_ims]
croppings = ct.utility.mt_loop(ct.image.compute_cropping, src_ims)

# Compute paddings.
if args.pad_pixel != 0:
Expand Down
68 changes: 68 additions & 0 deletions camtools/utility.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,71 @@
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import Any, Callable, Iterable

from tqdm import tqdm


def mt_loop(
func: Callable[[Any], Any],
inputs: Iterable[Any],
**kwargs,
) -> list:
"""
Applies a function to each item in the given list in parallel using multi-threading.
Args:
func (Callable[[Any], Any]): The function to apply. Must accept a single
argument.
inputs (Iterable[Any]): An iterable of inputs to process with the
function.
**kwargs: Additional keyword arguments to pass to `func`.
Returns:
list: A list of results from applying `func` to each item in `list_input`.
"""
desc = f"[mt] {func.__name__}"
with ThreadPoolExecutor() as executor:
futures = [executor.submit(func, item, **kwargs) for item in inputs]
progress = tqdm(
as_completed(futures),
total=len(inputs),
desc=desc,
)
results = [future.result() for future in progress]
return results


def mp_loop(
func: Callable[[Any], Any],
inputs: Iterable[Any],
**kwargs,
) -> list:
"""
Applies a function to each item in the given list in parallel using multi-processing.
Args:
func (Callable[[Any], Any]): The function to apply. Must accept a single
argument.
inputs (Iterable[Any]): An iterable of inputs to process with the
function.
**kwargs: Additional keyword arguments to pass to `func`.
Returns:
list: A list of results from applying `func` to each item in `inputs`.
"""
desc = f"[mp] {func.__name__}"
with ProcessPoolExecutor() as executor:
future_to_item = {
executor.submit(func, item, **kwargs): item for item in inputs
}
progress = tqdm(
as_completed(future_to_item),
total=len(inputs),
desc=desc,
)
results = [future.result() for future in progress]
return results


def query_yes_no(question, default=None):
"""Ask a yes/no question via raw_input() and return their answer.
Expand Down

0 comments on commit 7506d36

Please sign in to comment.