Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lrmodule/copy_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class CopyCSV(Aggregation):
def __init__(self, source_file: str, target_dir: str, columns: list[str], new_file_name: str | None):
self.source_file = search_path(Path(source_file))
self.target_dir = Path(target_dir)

# Ensure the target directory exists or create it
self.target_dir.mkdir(parents=True, exist_ok=True)

self.columns = columns
if new_file_name is None:
self.new_file_name = self.target_dir / self.source_file.name
Expand Down
15 changes: 0 additions & 15 deletions lrmodule/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ def transform_marktype_ccf(original_score):
return transformed_score.reshape(-1, 1)


def select_marktype_accf(all_features: np.ndarray) -> np.ndarray:
    """Return the 'accf' score column of the feature matrix.

    The 'accf' score is expected in the first column of the input features.
    """
    accf_column = 0
    return all_features[:, accf_column]


def transform_marktype_accf(original_score):
"""Transform the 'accf' score using a log transformation.

Expand All @@ -32,18 +27,8 @@ def transform_marktype_accf(original_score):
def transform_marktype_rel_cmc(original_score: np.ndarray) -> np.ndarray:
    """Compute the relative 'cmc' score as the element-wise ratio cmc / n.

    Expects the two relevant columns ('cmc' first, 'n' second), e.g. as
    produced by select_marktype_cmc beforehand.
    Currently not used, but equivalent to the Matlab implementation.
    """
    cmc, n = original_score[:, 0], original_score[:, 1]
    return cmc / n


def select_marktype_cmc(all_features: np.ndarray) -> np.ndarray:
    """Select the 'cmc' and 'n' columns from the input features.

    Raises:
        ValueError: if any row has n smaller than cmc.
    """
    # 'cmc' is expected in column 1 and 'n' in column 2 of the full matrix.
    selected = all_features[:, 1:3]
    cmc, n = selected[:, 0], selected[:, 1]
    if np.any(n - cmc < 0):
        raise ValueError("n must be larger than or equal to cmc")
    return selected
51 changes: 4 additions & 47 deletions lrmodule/input_data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import logging
from collections.abc import Iterable
from enum import StrEnum
from functools import cache
from pathlib import Path

from lir.data.io import search_path
from lir.data.models import DataStrategy, FeatureData
from lir.datasets.feature_data_csv import ExtraField, FeatureDataCsvParser

LOG = logging.getLogger(__name__)

Expand All @@ -17,52 +13,13 @@ class TestTrainSplit(StrEnum):
TEST = "v"


SPLIT_COLUMNS = ["split1", "split2", "split3"]


class ScratchCsvReader(FeatureDataCsvParser):
    # Thin CSV-parsing adapter: all actual parsing is delegated to the
    # FeatureDataCsvParser base class; this subclass only fixes the
    # Scratch-specific column layout and resolves the file path.
    def __init__(self, input_file_path: Path | str):
        """Read and represent Scratch specific input data, as corresponding instances.

        The data might include n-fold cross validation splits, where each fold has a train/test split.
        This class provides access to iterate over the available folds and the corresponding train and test splits.
        """
        super().__init__(
            # A source is identified by the pair of compared weapons.
            source_id_column=["weapon1", "weapon2"],
            label_column="hypothesis",
            extra_fields=[
                # Collect the per-fold role columns under one 'split' field.
                ExtraField("split", SPLIT_COLUMNS, str),
            ],
            # Prefix parser diagnostics with the originating file.
            message_prefix=f"{input_file_path}: ",
        )

        self.file_path = Path(input_file_path)

    # NOTE(review): @cache on an instance method keys on `self` and keeps the
    # instance (and its parsed data) alive for the cache's lifetime (ruff B019).
    # Presumably acceptable for a reader created once per run — confirm.
    @cache
    def get_instances(self) -> FeatureData:
        """Read K-fold cross validation CSV input data to a list of K corresponding subsets of test/train folds.

        In the CSV file, subsets of the data are indicated by the "split<N>" column. For example, 3-fold cross
        validation is represented through columns 'split1', 'split2' and 'split3' which indicate if the data in
        this subset belongs to the test split ("t"), train split ("v") or is not used ("n").

        The FeatureData instances are created from all columns that are not part of the expected columns:
        'weapon1', 'weapon2', 'hypothesis', 'split<N>' (for all N). The 'hypothesis' column is used as label.
        The remaining columns are treated as features. This means that the pipeline in which this data is used
        should filter out any non-relevant feature columns before training or evaluating a model.
        """
        # search_path resolves the configured path to an existing file location.
        path = search_path(self.file_path)
        LOG.debug(f"parsing CSV file: {self.file_path} as {path}")
        with open(path) as f:
            return self._parse_file(f)


class PredefinedCrossValidation(DataStrategy):
    """Return a series of train/test sets for a predefined cross-validation setup."""

    def apply(self, instances: FeatureData) -> Iterable[tuple[FeatureData, FeatureData]]:
        """Yield one (train, test) pair per predefined split column.

        The number of folds is derived from the width of the instances'
        split role assignments rather than a hard-coded column list, so any
        number of 'split<N>' columns is supported. Roles per column are the
        TestTrainSplit values: train ("t"... as configured), test, or unused.
        """
        # The 'split' extra field holds one role character per (row, fold).
        role_assignments = instances.split  # type: ignore
        for split in range(role_assignments.shape[1]):
            training_data = instances[role_assignments[:, split] == TestTrainSplit.TRAIN.value]
            test_data = instances[role_assignments[:, split] == TestTrainSplit.TEST.value]
            # Rows marked as unused for this fold are deliberately dropped.
            yield training_data, test_data
15 changes: 13 additions & 2 deletions models/aperture_shear/validation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,19 @@ experiments:
lr_system: *aperture_shear_ccf_lrsystem
data:
provider:
method: lrmodule.input_data.ScratchCsvReader
input_file_path: ${input_file_path}
method: parse_features_from_csv_file
file: ${input_file_path}
label_column: hypothesis
source_id_column:
- weapon1
- weapon2
extra_fields:
- name: split
columns:
- split1
- split2
- split3
cell_type: str

splits:
strategy: lrmodule.input_data.PredefinedCrossValidation
Expand Down
19 changes: 15 additions & 4 deletions models/breech_face_impression/validation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ lrsystem: &lrsystem
architecture: lrmodule.binary_lrsystem
modules:
steps:
select:
method: lrmodule.helpers.select_marktype_cmc
mcmc:
method: mcmc
# parameters copied from MCMC test
Expand Down Expand Up @@ -37,8 +35,21 @@ experiments:
lr_system: *lrsystem
data:
provider:
method: lrmodule.input_data.ScratchCsvReader
input_file_path: ${input_file_path}
method: parse_features_from_csv_file
file: ${input_file_path}
label_column: hypothesis
source_id_column:
- weapon1
- weapon2
extra_fields:
- name: split
columns:
- split1
- split2
- split3
cell_type: str
ignore_columns:
- accf
splits:
strategy: lrmodule.input_data.PredefinedCrossValidation
output:
Expand Down
19 changes: 15 additions & 4 deletions models/firing_pin_impression/validation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ lrsystem: &lrsystem
architecture: lrmodule.binary_lrsystem
modules:
steps:
select:
method: lrmodule.helpers.select_marktype_cmc
mcmc:
method: mcmc
# parameters copied from MCMC test
Expand Down Expand Up @@ -37,8 +35,21 @@ experiments:
lr_system: *lrsystem
data:
provider:
method: lrmodule.input_data.ScratchCsvReader
input_file_path: ${input_file_path}
method: parse_features_from_csv_file
file: ${input_file_path}
label_column: hypothesis
source_id_column:
- weapon1
- weapon2
extra_fields:
- name: split
columns:
- split1
- split2
- split3
cell_type: str
ignore_columns:
- accf
splits:
strategy: lrmodule.input_data.PredefinedCrossValidation
output:
Expand Down
19 changes: 14 additions & 5 deletions tests/test_input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@

import numpy as np
from lir.data.models import FeatureData
from lir.datasets.feature_data_csv import ExtraField, FeatureDataCsvFileParser
from numpy import array

from lrmodule.input_data import ScratchCsvReader, PredefinedCrossValidation
from lrmodule.input_data import PredefinedCrossValidation


def test_input_data_to_instances():
"""Check that input data is correctly parsed to instances (having multiple folds)."""
# Arrange
input_file = Path(__file__).parent / "fixtures/input_data/train_test_data.csv"
dataset = ScratchCsvReader(input_file).get_instances()
dataset = FeatureDataCsvFileParser(
file=input_file,
label_column="hypothesis",
source_id_column=["weapon1", "weapon2"],
extra_fields=[ExtraField("split", ["split1", "split2", "split3"], str)],
).get_instances()
strategy = PredefinedCrossValidation()

# The following train/test splits for the given data_subsets are expected
Expand All @@ -34,15 +40,18 @@ def test_input_data_to_instances():
]

# Act
assert dataset.split.shape == (5, 3), "role assignment shape should match the input data"
assert np.all(dataset.split[:, 0] == np.array(['t', 'v', 'v', 'n', 't']))
split = getattr(dataset, "split")
assert split.shape == (5, 3), "role assignment shape should match the input data"
assert np.all(split[:, 0] == np.array(["t", "v", "v", "n", "t"]))

data_subsets = list(strategy.apply(dataset))

# Assert
# The fixture contains 3 subsets of data (3-fold cross validation)
assert len(data_subsets) == 3 # noqa: PLR2004 (magic number)

for i, ((actual_train, actual_test), (expected_train, expected_test)) in enumerate(zip(data_subsets, [subset_1, subset_2, subset_3])):
for i, ((actual_train, actual_test), (expected_train, expected_test)) in enumerate(
zip(data_subsets, [subset_1, subset_2, subset_3])
):
assert FeatureData(features=actual_train.features, labels=actual_train.labels) == expected_train
assert FeatureData(features=actual_test.features, labels=actual_test.labels) == expected_test
Loading