diff --git a/tests/test_structure.py b/tests/test_structure.py index 8ff900d..fab0138 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -19,9 +19,13 @@ def test_baseline_id(na, auto_corrs): @pytest.mark.parametrize("simmed_ms", [{"name": "proxy.ms"}], indirect=True) def test_structure_factory(simmed_ms): + partition_columns = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"] table_factory = TableFactory(Table.from_filename, simmed_ms) - structure_factory = MSv2StructureFactory(table_factory) + structure_factory = MSv2StructureFactory(table_factory, partition_columns) assert pickle.loads(pickle.dumps(structure_factory)) == structure_factory - structure_factory2 = MSv2StructureFactory(table_factory) + structure_factory2 = MSv2StructureFactory(table_factory, partition_columns) assert structure_factory() is structure_factory2() + + keys = tuple(k for kv in structure_factory().keys() for k, _ in kv) + assert tuple(sorted(partition_columns)) == keys diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py index ab324ff..786a217 100644 --- a/xarray_ms/backend/msv2/entrypoint.py +++ b/xarray_ms/backend/msv2/entrypoint.py @@ -3,7 +3,7 @@ import os import warnings from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Dict, Iterable +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple from uuid import uuid4 import xarray @@ -33,19 +33,11 @@ from xarray_ms.backend.msv2.structure import PartitionKeyT -def table_factory_factory(ms: str, ninstances: int) -> TableFactory: - """ - Ensures consistency when creating a TableFactory. - Multiple calls to this method with the same argument will - resolve to the same cached instance. 
- """ - return TableFactory( - Table.from_filename, - ms, - ninstances=ninstances, - readonly=True, - lockoptions="nolock", - ) +DEFAULT_PARTITION_COLUMNS: List[str] = [ + "DATA_DESC_ID", + "FIELD_ID", + "OBSERVATION_ID", +] def promote_chunks( @@ -78,13 +70,54 @@ def promote_chunks( return return_chunks +def initialise_default_args( + ms: str, + ninstances: int, + auto_corrs: bool, + epoch: str | None, + table_factory: TableFactory | None, + partition_columns: List[str] | None, + partition_key: PartitionKeyT | None, + structure_factory: MSv2StructureFactory | None, +) -> Tuple[str, TableFactory, List[str], PartitionKeyT, MSv2StructureFactory]: + """ + Ensures consistency when initialising default arguments from multiple locations + """ + if not os.path.exists(ms): + raise ValueError(f"MS {ms} does not exist") + + table_factory = table_factory or TableFactory( + Table.from_filename, + ms, + ninstances=ninstances, + readonly=True, + lockoptions="nolock", + ) + epoch = epoch or uuid4().hex[:8] + partition_columns = partition_columns or DEFAULT_PARTITION_COLUMNS + structure_factory = structure_factory or MSv2StructureFactory( + table_factory, partition_columns, auto_corrs=auto_corrs + ) + structure = structure_factory() + if partition_key is None: + partition_key = next(iter(structure.keys())) + warnings.warn( + f"No partition_key was supplied. 
Selected first partition {partition_key}" + ) + elif partition_key not in structure: + raise ValueError(f"{partition_key} not in {list(structure.keys())}") + + return epoch, table_factory, partition_columns, partition_key, structure_factory + + class MSv2Store(AbstractWritableDataStore): """Store for reading and writing MSv2 data""" __slots__ = ( "_table_factory", "_structure_factory", - "_partition", + "_partition_columns", + "_partition_key", "_auto_corrs", "_ninstances", "_epoch", @@ -92,6 +125,7 @@ class MSv2Store(AbstractWritableDataStore): _table_factory: TableFactory _structure_factory: MSv2StructureFactory + _partition_columns: List[str] - _partition: PartitionKeyT + _partition_key: PartitionKeyT _autocorrs: bool _ninstances: int @@ -101,14 +135,16 @@ def __init__( self, table_factory: TableFactory, structure_factory: MSv2StructureFactory, - partition: PartitionKeyT, + partition_columns: List[str], + partition_key: PartitionKeyT, auto_corrs: bool, ninstances: int, epoch: str, ): self._table_factory = table_factory self._structure_factory = structure_factory - self._partition = partition + self._partition_columns = partition_columns + self._partition_key = partition_key self._auto_corrs = auto_corrs self._ninstances = ninstances self._epoch = epoch @@ -118,7 +154,8 @@ def open( cls, ms: str, drop_variables=None, - partition: PartitionKeyT | None = None, + partition_columns: List[str] | None = None, + partition_key: PartitionKeyT | None = None, auto_corrs: bool = True, ninstances: int = 1, epoch: str | None = None, @@ -127,23 +164,24 @@ def open( if not isinstance(ms, str): raise ValueError("Measurement Sets paths must be strings") - table_factory = table_factory_factory(ms, ninstances) - epoch = epoch or uuid4().hex[:8] - structure_factory = structure_factory or MSv2StructureFactory( - table_factory, auto_corrs + epoch, table_factory, partition_columns, partition_key, structure_factory = ( + initialise_default_args( + ms, + ninstances, + auto_corrs, + epoch, + None, + partition_columns, + 
partition_key, + structure_factory, + ) ) - structure = structure_factory() - - if partition is None: - partition = next(iter(structure.keys())) - warnings.warn(f"No partition was supplied. Selected first partition {partition}") - elif partition not in structure: - raise ValueError(f"{partition} not in {list(structure.keys())}") return cls( table_factory, structure_factory, - partition=partition, + partition_columns=partition_columns, + partition_key=partition_key, auto_corrs=auto_corrs, ninstances=ninstances, epoch=epoch, @@ -154,12 +192,12 @@ def close(self, **kwargs): def get_variables(self): return MainDatasetFactory( - self._partition, self._table_factory, self._structure_factory + self._partition_key, self._table_factory, self._structure_factory ).get_variables() def get_attrs(self): try: - ddid = next(iter(v for k, v in self._partition if k == "DATA_DESC_ID")) + ddid = next(iter(v for k, v in self._partition_key if k == "DATA_DESC_ID")) except StopIteration: raise KeyError("DATA_DESC_ID not found in partition") @@ -183,7 +221,7 @@ def get_encoding(self): class MSv2PartitionEntryPoint(BackendEntrypoint): open_dataset_parameters = [ "filename_or_obj", - "partition", + "partition_columns", "partition_key", "auto_corrs", "ninstances", "epoch", @@ -212,14 +250,11 @@ def guess_can_open( def open_dataset( self, - filename_or_obj: str - | os.PathLike[Any] - | BufferedIOBase - | AbstractDataStore - | TableFactory, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, drop_variables: str | Iterable[str] | None = None, - partition=None, + partition_columns=None, + partition_key=None, auto_corrs=True, ninstances=8, epoch=None, @@ -229,7 +264,8 @@ def open_dataset( store = MSv2Store.open( filename_or_obj, drop_variables=drop_variables, - partition=partition, + partition_columns=partition_columns, + partition_key=partition_key, auto_corrs=auto_corrs, ninstances=ninstances, epoch=epoch, @@ -243,6 +279,7 @@ def open_datatree( filename_or_obj: str 
| os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, drop_variables: str | Iterable[str] | None = None, + partition_columns=None, auto_corrs=True, ninstances=8, epoch=None, @@ -255,10 +292,11 @@ def open_datatree( else: raise ValueError("Measurement Set paths must be strings") - table_factory = table_factory_factory(ms, ninstances) - structure_factory = MSv2StructureFactory(table_factory, auto_corrs=auto_corrs) - structure = structure_factory() + epoch, _, partition_columns, _, structure_factory = initialise_default_args( + ms, ninstances, auto_corrs, epoch, None, partition_columns, None, None + ) + structure = structure_factory() datasets = {} chunks = kwargs.pop("chunks", None) pchunks = promote_chunks(structure, chunks) @@ -267,7 +305,8 @@ def open_datatree( ds = xarray.open_dataset( ms, drop_variables=drop_variables, - partition=partition_key, + partition_columns=partition_columns, + partition_key=partition_key, auto_corrs=auto_corrs, ninstances=ninstances, epoch=epoch, diff --git a/xarray_ms/backend/msv2/structure.py b/xarray_ms/backend/msv2/structure.py index d92bb85..b590119 100644 --- a/xarray_ms/backend/msv2/structure.py +++ b/xarray_ms/backend/msv2/structure.py @@ -103,17 +103,10 @@ def is_partition_key(key: PartitionKeyT) -> bool: ) -PARTITION_COLUMNS: List[str] = [ - "DATA_DESC_ID", - "FIELD_ID", - "PROCESSOR_ID", - "FEED1", - "FEED2", -] - SHORT_TO_LONG_PARTITION_COLUMNS: Dict[str, str] = { "D": "DATA_DESC_ID", "F": "FIELD_ID", + "O": "OBSERVATION_ID", "P": "PROCESSOR_ID", "F1": "FEED1", "F2": "FEED2", @@ -145,10 +138,14 @@ class PartitionData: class TablePartitioner: """Partitions and sorts MSv2 indexing columns""" + _partitionby: List[str] _sortby: List[str] _other: List[str] - def __init__(self, sortby: Sequence[str], other: Sequence[str]): + def __init__( + self, partitionby: Sequence[str], sortby: Sequence[str], other: Sequence[str] + ): + self._partitionby = list(partitionby) self._sortby = list(sortby) self._other = list(other) @@ -170,9 
+167,9 @@ def partition(self, index: pa.Table) -> Dict[PartitionKeyT, pa.Table]: raise ValueError(f"{read_columns} is not a subset of {index.column_names}") agg_cmd = [ - (c, "list") for c in (maybe_row | set(read_columns) - set(PARTITION_COLUMNS)) + (c, "list") for c in (maybe_row | set(read_columns) - set(self._partitionby)) ] - partitions = index.group_by(PARTITION_COLUMNS).aggregate(agg_cmd) + partitions = index.group_by(self._partitionby).aggregate(agg_cmd) renames = {f"{c}_list": c for c, _ in agg_cmd} partitions = partitions.rename_columns( renames.get(c, c) for c in partitions.column_names @@ -182,7 +179,7 @@ def partition(self, index: pa.Table) -> Dict[PartitionKeyT, pa.Table]: for p in range(len(partitions)): key: PartitionKeyT = tuple( - sorted((c, int(partitions[c][p].as_py())) for c in PARTITION_COLUMNS) + sorted((c, int(partitions[c][p].as_py())) for c in self._partitionby) ) table_dict = {c: partitions[c][p].values for c in read_columns | maybe_row} partition_table = pa.Table.from_pydict(table_dict) @@ -204,13 +201,17 @@ class MSv2StructureFactory: for creating and caching an MSv2Structure""" _ms_factory: TableFactory + _partition_columns: List[str] _auto_corrs: bool _STRUCTURE_CACHE: ClassVar[Cache] = Cache( maxsize=100, ttl=60, on_get=on_get_keep_alive ) - def __init__(self, ms: TableFactory, auto_corrs: bool = True): + def __init__( + self, ms: TableFactory, partition_columns: List[str], auto_corrs: bool = True + ): self._ms_factory = ms + self._partition_columns = partition_columns self._auto_corrs = auto_corrs def __eq__(self, other: Any) -> bool: @@ -218,19 +219,27 @@ def __eq__(self, other: Any) -> bool: return NotImplemented return ( - self._ms_factory == other._ms_factory and self._auto_corrs == other._auto_corrs + self._ms_factory == other._ms_factory + and self._partition_columns == other._partition_columns + and self._auto_corrs == other._auto_corrs ) def __hash__(self): - return hash((self._ms_factory, self._auto_corrs)) + return 
hash((self._ms_factory, tuple(self._partition_columns), self._auto_corrs)) def __reduce__(self): - return (MSv2StructureFactory, (self._ms_factory, self._auto_corrs)) + return ( + MSv2StructureFactory, + (self._ms_factory, self._partition_columns, self._auto_corrs), + ) def __call__(self, *args, **kw) -> MSv2Structure: assert not args and not kw return self._STRUCTURE_CACHE.get( - self, lambda self: MSv2Structure(self._ms_factory, self._auto_corrs) + self, + lambda self: MSv2Structure( + self._ms_factory, self._partition_columns, self._auto_corrs + ), ) @@ -239,6 +248,7 @@ class MSv2Structure(Mapping): _ms_factory: TableFactory _auto_corrs: bool + _partition_columns: List[str] _partitions: Mapping[PartitionKeyT, PartitionData] _column_descs: Dict[str, Dict[str, ColumnDesc]] _ant: pa.Table @@ -305,7 +315,7 @@ def resolve_key(self, key: str | PartitionKeyT) -> List[PartitionKeyT]: if isinstance(key, str): key = self.parse_partition_key(key) - column_set = set(PARTITION_COLUMNS) + column_set = set(self._partition_columns) # Check that the key columns and values are valid new_key: List[Tuple[str, int]] = [] @@ -314,7 +324,8 @@ def resolve_key(self, key: str | PartitionKeyT) -> List[PartitionKeyT]: column = SHORT_TO_LONG_PARTITION_COLUMNS.get(column, column) if column not in column_set: raise InvalidPartitionKey( - f"{column} is not valid a valid partition column " f"{PARTITION_COLUMNS}" + f"{column} is not a valid partition column " + f"{self._partition_columns}" ) if not isinstance(value, Integral): raise InvalidPartitionKey(f"{value} is not a valid partition value") @@ -331,12 +342,18 @@ def resolve_key(self, key: str | PartitionKeyT) -> List[PartitionKeyT]: return matches - def __init__(self, ms: TableFactory, auto_corrs: bool = True): + def __init__( + self, ms: TableFactory, partition_columns: List[str], auto_corrs: bool = True + ): import time as modtime start = modtime.time() + if "DATA_DESC_ID" not in partition_columns: + raise ValueError("DATA_DESC_ID 
must be included as a partitioning column") + self._ms_factory = ms + self._partition_columns = partition_columns self._auto_corrs = auto_corrs table = ms() @@ -385,11 +402,11 @@ def __init__(self, ms: TableFactory, auto_corrs: bool = True): self._column_descs = FrozenDict(col_descs) other_columns = ["INTERVAL"] - read_columns = set(PARTITION_COLUMNS) | set(SORT_COLUMNS) | set(other_columns) + read_columns = set(partition_columns) | set(SORT_COLUMNS) | set(other_columns) index = table.to_arrow(columns=read_columns) - partitions = TablePartitioner(SORT_COLUMNS, other_columns + ["row"]).partition( - index - ) + partitions = TablePartitioner( + partition_columns, SORT_COLUMNS, other_columns + ["row"] + ).partition(index) self._partitions = {}