Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 76 additions & 10 deletions src/hats/pixel_math/partition_stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Utilities for generating and manipulating object count histograms"""

import logging

import numpy as np
import pandas as pd
import pyarrow.parquet as pq

import hats.pixel_math.healpix_shim as hp

Expand Down Expand Up @@ -54,7 +57,12 @@ def generate_histogram(


def generate_alignment(
histogram, highest_order=10, lowest_order=0, threshold=1_000_000, drop_empty_siblings=False
histogram,
highest_order=10,
lowest_order=0,
threshold=1_000_000,
byte_pixel_threshold=None,
drop_empty_siblings=False,
):
"""Generate alignment from high order pixels to those of equal or lower order

Expand All @@ -66,20 +74,27 @@ def generate_alignment(
Args:
histogram (:obj:`np.array`): one-dimensional numpy array of long integers where the
value at each index corresponds to the number of objects found at the healpix pixel.
highest_order (int): the highest healpix order (e.g. 5-10)
highest_order (int): the highest healpix order (e.g. 5-10)
lowest_order (int): the lowest healpix order (e.g. 1-5). specifying a lowest order
constrains the partitioning to prevent spatially large pixels.
threshold (int): the maximum number of objects allowed in a single pixel
byte_pixel_threshold (int | None): the maximum number of objects allowed in a single pixel,
expressed in bytes. if this is set, it will override `threshold`.
drop_empty_siblings (bool): if 3 of 4 pixels are empty, keep only the non-empty pixel

Returns:
one-dimensional numpy array of integer 3-tuples, where the value at each index corresponds
to the destination pixel at order less than or equal to the `highest_order`.

The tuple contains three integers:

- order of the destination pixel
- pixel number *at the above order*
- the number of objects in the pixel
- the number of objects in the pixel (if partitioning by row count), or the memory size (if
partitioning by memory)

Note:
If partitioning is done by memory size, the row count per partition may vary widely and will
not match the row count histogram's bins.
Raises:
ValueError: if the histogram is the wrong size, or some initial histogram bins
exceed threshold.
Expand All @@ -88,9 +103,21 @@ def generate_alignment(
raise ValueError("histogram is not the right size")
if lowest_order > highest_order:
raise ValueError("lowest_order should be less than highest_order")

# Determine aggregation type and threshold
if byte_pixel_threshold is not None:
agg_threshold = byte_pixel_threshold
agg_type = "mem_size"
else:
agg_threshold = threshold
agg_type = "row_count"

# Check that none of the high-order pixels already exceed the threshold.
max_bin = np.amax(histogram)
if max_bin > threshold:
raise ValueError(f"single pixel count {max_bin} exceeds threshold {threshold}")
if agg_type == "mem_size" and max_bin > agg_threshold:
raise ValueError(f"single pixel size {max_bin} bytes exceeds byte_pixel_threshold {agg_threshold}")
if agg_type == "row_count" and max_bin > agg_threshold:
raise ValueError(f"single pixel count {max_bin} exceeds threshold {agg_threshold}")

nested_sums = []
for i in range(0, highest_order):
Expand All @@ -104,9 +131,10 @@ def generate_alignment(
parent_pixel = index >> 2
nested_sums[parent_order][parent_pixel] += nested_sums[read_order][index]

# Use the aggregation threshold for alignment
if drop_empty_siblings:
return _get_alignment_dropping_siblings(nested_sums, highest_order, lowest_order, threshold)
return _get_alignment(nested_sums, highest_order, lowest_order, threshold)
return _get_alignment_dropping_siblings(nested_sums, highest_order, lowest_order, agg_threshold)
return _get_alignment(nested_sums, highest_order, lowest_order, agg_threshold)


def _get_alignment(nested_sums, highest_order, lowest_order, threshold):
Expand All @@ -129,9 +157,9 @@ def _get_alignment(nested_sums, highest_order, lowest_order, threshold):

if parent_alignment:
nested_alignment[read_order][index] = parent_alignment
elif nested_sums[read_order][index] == 0:
elif nested_sums[read_order][index] == 0: # pylint: disable=no-else-raise
continue
elif nested_sums[read_order][index] <= threshold:
elif nested_sums[read_order][index] <= threshold: # pylint: disable=no-else-raise
nested_alignment[read_order][index] = (
read_order,
index,
Expand Down Expand Up @@ -201,3 +229,41 @@ def _get_alignment_dropping_siblings(nested_sums, highest_order, lowest_order, t
]

return np.array(nested_alignment, dtype="object")


def generate_row_count_histogram_from_partitions(partition_files, pixel_orders, pixel_indices, highest_order):
"""Generate a row count histogram from a list of partition files and their pixel indices/orders.

Args:
partition_files (list[str or UPath]): List of paths to partition files.
pixel_orders (list[int]): List of healpix orders for each partition.
pixel_indices (list[int]): List of healpix pixel indices for each partition.
highest_order (int): The highest healpix order (for histogram size).

Returns:
np.ndarray: One-dimensional numpy array of long integers, where the value at each index
corresponds to the number of rows found at the healpix pixel.

Note:
If partitioning was done by memory size, this histogram will reflect the actual row counts
in the output partitions, which may differ significantly from the original row count histogram.
"""
histogram = np.zeros(hp.order2npix(highest_order), dtype=np.int64)
for file_path, order, pix in zip(partition_files, pixel_orders, pixel_indices):
try:
table = pq.read_table(file_path)
row_count = len(table)
# Map pixel index to highest_order if needed
if order == highest_order:
histogram[pix] += row_count
else:
# Map lower order pixel to highest_order pixel indices
# Each lower order pixel covers 4**(highest_order - order) pixels
factor = 4 ** (highest_order - order)
start = pix * factor
end = (pix + 1) * factor
histogram[start:end] += row_count // factor
except (OSError, ValueError) as e:
logging.warning("Could not read partition file %s: %s", file_path, e)
continue
return histogram