Skip to content

Commit

Permalink
Merge pull request #3 from ratt-ru/configurable-partitioning-columns
Browse files Browse the repository at this point in the history
Make partitioning columns configurable
  • Loading branch information
sjperkins committed Sep 10, 2024
2 parents 7582130 + a01b9ba commit cff65e3
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 70 deletions.
8 changes: 6 additions & 2 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ def test_baseline_id(na, auto_corrs):

@pytest.mark.parametrize("simmed_ms", [{"name": "proxy.ms"}], indirect=True)
def test_structure_factory(simmed_ms):
    """Check pickling, caching and partition keys of ``MSv2StructureFactory``."""
    partition_columns = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"]
    table_factory = TableFactory(Table.from_filename, simmed_ms)

    # The factory must survive a pickle round trip and still compare equal.
    structure_factory = MSv2StructureFactory(table_factory, partition_columns)
    assert pickle.loads(pickle.dumps(structure_factory)) == structure_factory

    # A second factory built from the same arguments must hand back the
    # very same structure object (identity, not just equality).
    structure_factory2 = MSv2StructureFactory(table_factory, partition_columns)
    assert structure_factory() is structure_factory2()

    # Flatten the column names out of every partition key and compare them
    # with the sorted requested columns.
    # NOTE(review): the flattened tuple only equals the sorted column list
    # when the structure holds exactly one partition -- presumably true for
    # the "proxy.ms" fixture; confirm.
    keys = tuple(k for kv in structure_factory().keys() for k, _ in kv)
    assert tuple(sorted(partition_columns)) == keys
127 changes: 83 additions & 44 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import warnings
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, Iterable
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple
from uuid import uuid4

import xarray
Expand Down Expand Up @@ -33,19 +33,11 @@
from xarray_ms.backend.msv2.structure import PartitionKeyT


def table_factory_factory(ms: str, ninstances: int) -> TableFactory:
    """
    Single place where a TableFactory for a Measurement Set is built,
    so that every caller constructs it with the same arguments.

    The table is opened through ``Table.from_filename`` in read-only
    mode with the "nolock" lock option.

    NOTE(review): the original author states that repeated calls with
    the same arguments resolve to the same cached instance; that caching
    would have to live inside TableFactory and is not visible here.
    """
    open_options = {
        "ninstances": ninstances,
        "readonly": True,
        "lockoptions": "nolock",
    }
    factory = TableFactory(Table.from_filename, ms, **open_options)
    return factory
# Columns used to partition the Measurement Set when the caller supplies
# no (or an empty) ``partition_columns``; read by ``initialise_default_args``.
DEFAULT_PARTITION_COLUMNS: List[str] = [
"DATA_DESC_ID",
"FIELD_ID",
"OBSERVATION_ID",
]


def promote_chunks(
Expand Down Expand Up @@ -78,20 +70,62 @@ def promote_chunks(
return return_chunks


def initialise_default_args(
    ms: str,
    ninstances: int,
    auto_corrs: bool,
    epoch: str | None,
    table_factory: TableFactory | None,
    partition_columns: List[str] | None,
    partition_key: PartitionKeyT | None,
    structure_factory: MSv2StructureFactory | None,
) -> Tuple[str, TableFactory, List[str], PartitionKeyT, MSv2StructureFactory]:
    """
    Ensures consistency when initialising default arguments from multiple locations

    Every argument that is ``None`` (or otherwise falsy) is replaced by a
    default derived from the remaining arguments:

    * ``table_factory``: a read-only, "nolock" ``TableFactory`` over ``ms``
      holding ``ninstances`` instances.
    * ``epoch``: the first 8 hex characters of a fresh ``uuid4``.
    * ``partition_columns``: a copy of ``DEFAULT_PARTITION_COLUMNS``.
    * ``structure_factory``: an ``MSv2StructureFactory`` built from the
      table factory, the partition columns and ``auto_corrs``.
    * ``partition_key``: the first key of the structure produced by the
      structure factory; a warning reports which key was selected.

    Returns:
      ``(epoch, table_factory, partition_columns, partition_key,
      structure_factory)`` with all defaults filled in.

    Raises:
      ValueError: if ``ms`` does not exist, if no ``partition_key`` was
        given and the structure contains no partitions, or if the given
        ``partition_key`` is not present in the structure.
    """
    if not os.path.exists(ms):
        raise ValueError(f"MS {ms} does not exist")

    table_factory = table_factory or TableFactory(
        Table.from_filename,
        ms,
        ninstances=ninstances,
        readonly=True,
        lockoptions="nolock",
    )
    epoch = epoch or uuid4().hex[:8]
    # Hand out a copy so that callers mutating the returned list cannot
    # alter the module-level default for everybody else.
    partition_columns = partition_columns or list(DEFAULT_PARTITION_COLUMNS)
    structure_factory = structure_factory or MSv2StructureFactory(
        table_factory, partition_columns, auto_corrs=auto_corrs
    )
    structure = structure_factory()

    if partition_key is None:
        try:
            partition_key = next(iter(structure.keys()))
        except StopIteration:
            # A bare StopIteration escaping from here is easy to swallow by
            # accident (e.g. inside a generator); report it like the other
            # invalid-input cases instead.
            raise ValueError(
                f"MS {ms} has no partitions for columns {partition_columns}"
            ) from None
        warnings.warn(
            f"No partition_key was supplied. Selected first partition {partition_key}"
        )
    elif partition_key not in structure:
        raise ValueError(f"{partition_key} not in {list(structure.keys())}")

    return epoch, table_factory, partition_columns, partition_key, structure_factory


class MSv2Store(AbstractWritableDataStore):
"""Store for reading and writing MSv2 data"""

__slots__ = (
"_table_factory",
"_structure_factory",
"_partition",
"_partition_columns",
"_partition_key",
"_auto_corrs",
"_ninstances",
"_epoch",
)

# Annotations for the attributes assigned in ``__init__``.
_table_factory: TableFactory
_structure_factory: MSv2StructureFactory
_partition_columns: List[str]
_partition_key: PartitionKeyT
_auto_corrs: bool
_ninstances: int
Expand All @@ -101,14 +135,16 @@ def __init__(
self,
table_factory: TableFactory,
structure_factory: MSv2StructureFactory,
partition: PartitionKeyT,
partition_columns: List[str],
partition_key: PartitionKeyT,
auto_corrs: bool,
ninstances: int,
epoch: str,
):
self._table_factory = table_factory
self._structure_factory = structure_factory
self._partition = partition
self._partition_columns = partition_columns
self._partition_key = partition_key
self._auto_corrs = auto_corrs
self._ninstances = ninstances
self._epoch = epoch
Expand All @@ -118,7 +154,8 @@ def open(
cls,
ms: str,
drop_variables=None,
partition: PartitionKeyT | None = None,
partition_columns: List[str] | None = None,
partition_key: PartitionKeyT | None = None,
auto_corrs: bool = True,
ninstances: int = 1,
epoch: str | None = None,
Expand All @@ -127,23 +164,24 @@ def open(
if not isinstance(ms, str):
raise ValueError("Measurement Sets paths must be strings")

table_factory = table_factory_factory(ms, ninstances)
epoch = epoch or uuid4().hex[:8]
structure_factory = structure_factory or MSv2StructureFactory(
table_factory, auto_corrs
epoch, table_factory, partition_columns, partition_key, structure_factory = (
initialise_default_args(
ms,
ninstances,
auto_corrs,
epoch,
None,
partition_columns,
partition_key,
structure_factory,
)
)
structure = structure_factory()

if partition is None:
partition = next(iter(structure.keys()))
warnings.warn(f"No partition was supplied. Selected first partition {partition}")
elif partition not in structure:
raise ValueError(f"{partition} not in {list(structure.keys())}")

return cls(
table_factory,
structure_factory,
partition=partition,
partition_columns=partition_columns,
partition_key=partition_key,
auto_corrs=auto_corrs,
ninstances=ninstances,
epoch=epoch,
Expand All @@ -154,12 +192,12 @@ def close(self, **kwargs):

def get_variables(self):
return MainDatasetFactory(
self._partition, self._table_factory, self._structure_factory
self._partition_key, self._table_factory, self._structure_factory
).get_variables()

def get_attrs(self):
try:
ddid = next(iter(v for k, v in self._partition if k == "DATA_DESC_ID"))
ddid = next(iter(v for k, v in self._partition_key if k == "DATA_DESC_ID"))
except StopIteration:
raise KeyError("DATA_DESC_ID not found in partition")

Expand All @@ -183,7 +221,7 @@ def get_encoding(self):
class MSv2PartitionEntryPoint(BackendEntrypoint):
open_dataset_parameters = [
"filename_or_obj",
"partition",
"partition_columns" "partition_key",
"auto_corrs",
"ninstances",
"epoch",
Expand Down Expand Up @@ -212,14 +250,11 @@ def guess_can_open(

def open_dataset(
self,
filename_or_obj: str
| os.PathLike[Any]
| BufferedIOBase
| AbstractDataStore
| TableFactory,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
drop_variables: str | Iterable[str] | None = None,
partition=None,
partition_columns=None,
partition_key=None,
auto_corrs=True,
ninstances=8,
epoch=None,
Expand All @@ -229,7 +264,8 @@ def open_dataset(
store = MSv2Store.open(
filename_or_obj,
drop_variables=drop_variables,
partition=partition,
partition_columns=partition_columns,
partition_key=partition_key,
auto_corrs=auto_corrs,
ninstances=ninstances,
epoch=epoch,
Expand All @@ -243,6 +279,7 @@ def open_datatree(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
drop_variables: str | Iterable[str] | None = None,
partition_columns=None,
auto_corrs=True,
ninstances=8,
epoch=None,
Expand All @@ -255,10 +292,11 @@ def open_datatree(
else:
raise ValueError("Measurement Set paths must be strings")

table_factory = table_factory_factory(ms, ninstances)
structure_factory = MSv2StructureFactory(table_factory, auto_corrs=auto_corrs)
structure = structure_factory()
epoch, _, partition_columns, _, structure_factory = initialise_default_args(
ms, ninstances, auto_corrs, epoch, None, partition_columns, None, None
)

structure = structure_factory()
datasets = {}
chunks = kwargs.pop("chunks", None)
pchunks = promote_chunks(structure, chunks)
Expand All @@ -267,7 +305,8 @@ def open_datatree(
ds = xarray.open_dataset(
ms,
drop_variables=drop_variables,
partition=partition_key,
partition_columns=partition_columns,
partition_key=partition_key,
auto_corrs=auto_corrs,
ninstances=ninstances,
epoch=epoch,
Expand Down
Loading

0 comments on commit cff65e3

Please sign in to comment.