Skip to content

Commit e68ead9

Browse files
PimMeulensteen, Pim Meulensteen (DBS), wowtor
authored
Cleanup read_csv method / fix bug in copy-csv / remove hard-coded split column header (#29)
* fix bug in copy-csv
* remove unused methods and update validation YAML to use new parsing method
* use instances.split directly for split handling
* remove hard-coded split column header

---------

Co-authored-by: Pim Meulensteen (DBS) <pim.meulensteen@nfi.nl>
Co-authored-by: wabos <20113294+wowtor@users.noreply.github.com>
1 parent 9fcde3d commit e68ead9

7 files changed

Lines changed: 65 additions & 77 deletions

File tree

lrmodule/copy_csv.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,10 @@ class CopyCSV(Aggregation):
2222
def __init__(self, source_file: str, target_dir: str, columns: list[str], new_file_name: str | None):
2323
self.source_file = search_path(Path(source_file))
2424
self.target_dir = Path(target_dir)
25+
26+
# Ensure the target directory exists or create it
27+
self.target_dir.mkdir(parents=True, exist_ok=True)
28+
2529
self.columns = columns
2630
if new_file_name is None:
2731
self.new_file_name = self.target_dir / self.source_file.name

lrmodule/helpers.py

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,6 @@ def transform_marktype_ccf(original_score):
1212
return transformed_score.reshape(-1, 1)
1313

1414

15-
def select_marktype_accf(all_features: np.ndarray) -> np.ndarray:
16-
"""Select the 'accf' score from the input features."""
17-
return all_features[:, 0]
18-
19-
2015
def transform_marktype_accf(original_score):
2116
"""Transform the 'accf' score using a log transformation.
2217
@@ -32,18 +27,8 @@ def transform_marktype_accf(original_score):
3227
def transform_marktype_rel_cmc(original_score: np.ndarray) -> np.ndarray:
3328
"""Transform the 'rel_cmc' score by calculating the ratio between the two columns.
3429
35-
When used, select_marktype_cmc should be used first to select the relevant columns.
3630
Currently not used, but equivalent to the Matlab implementation.
3731
"""
3832
cmc = original_score[:, 0]
3933
n = original_score[:, 1]
4034
return cmc / n
41-
42-
43-
def select_marktype_cmc(all_features: np.ndarray) -> np.ndarray:
44-
"""Select 'cmc' and 'n' from the input features."""
45-
# The 'cmc' is expected to be in column 1 and 'n' in column 2.
46-
relevant_features = all_features[:, 1:3]
47-
if np.any(relevant_features[:, 1] - relevant_features[:, 0] < 0):
48-
raise ValueError("n must be larger than or equal to cmc")
49-
return relevant_features

lrmodule/input_data.py

Lines changed: 4 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,8 @@
11
import logging
22
from collections.abc import Iterable
33
from enum import StrEnum
4-
from functools import cache
5-
from pathlib import Path
64

7-
from lir.data.io import search_path
85
from lir.data.models import DataStrategy, FeatureData
9-
from lir.datasets.feature_data_csv import ExtraField, FeatureDataCsvParser
106

117
LOG = logging.getLogger(__name__)
128

@@ -17,52 +13,13 @@ class TestTrainSplit(StrEnum):
1713
TEST = "v"
1814

1915

20-
SPLIT_COLUMNS = ["split1", "split2", "split3"]
21-
22-
23-
class ScratchCsvReader(FeatureDataCsvParser):
24-
def __init__(self, input_file_path: Path | str):
25-
"""Read and represent Scratch specific input data, as corresponding instances.
26-
27-
The data might include n-fold cross validation splits, where each fold has a train/test split.
28-
This class provides access to iterate over the available folds and the corresponding train and test splits.
29-
"""
30-
super().__init__(
31-
source_id_column=["weapon1", "weapon2"],
32-
label_column="hypothesis",
33-
extra_fields=[
34-
ExtraField("split", SPLIT_COLUMNS, str),
35-
],
36-
message_prefix=f"{input_file_path}: ",
37-
)
38-
39-
self.file_path = Path(input_file_path)
40-
41-
@cache
42-
def get_instances(self) -> FeatureData:
43-
"""Read K-fold cross validation CSV input data to a list of K corresponding subsets of test/train folds.
44-
45-
In the CSV file, subsets of the data are indicated by the "split<N>" column. For example, 3-fold cross
46-
validation is represented through columns 'split1', 'split2' and 'split3' which indicate if the data in
47-
this subset belongs to the test split ("t"), train split ("v") or is not used ("n").
48-
49-
The FeatureData instances are created from all columns that are not part of the expected columns:
50-
'weapon1', 'weapon2', 'hypothesis', 'split<N>' (for all N). The 'hypothesis' column is used as label.
51-
The remaining columns are treated as features. This means that the pipeline in which this data is used
52-
should filter out any non-relevant feature columns before training or evaluating a model.
53-
"""
54-
path = search_path(self.file_path)
55-
LOG.debug(f"parsing CSV file: {self.file_path} as {path}")
56-
with open(path) as f:
57-
return self._parse_file(f)
58-
59-
6016
class PredefinedCrossValidation(DataStrategy):
6117
"""Return a series of train/test sets for a predefined cross-validation setup."""
6218

6319
def apply(self, instances: FeatureData) -> Iterable[tuple[FeatureData, FeatureData]]:
6420
"""Return a series of train/test sets for a predefined cross-validation setup."""
65-
for split in range(len(SPLIT_COLUMNS)):
66-
training_data = instances[instances.split[:, split] == TestTrainSplit.TRAIN.value] # type: ignore
67-
test_data = instances[instances.split[:, split] == TestTrainSplit.TEST.value] # type: ignore
21+
role_assignments = instances.split # type: ignore
22+
for split in range(role_assignments.shape[1]):
23+
training_data = instances[role_assignments[:, split] == TestTrainSplit.TRAIN.value]
24+
test_data = instances[role_assignments[:, split] == TestTrainSplit.TEST.value]
6825
yield training_data, test_data

lrmodule/models/aperture_shear/validation.yaml

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,19 @@ experiments:
2727
lr_system: *aperture_shear_ccf_lrsystem
2828
data:
2929
provider:
30-
method: lrmodule.input_data.ScratchCsvReader
31-
input_file_path: ${input_file_path}
30+
method: parse_features_from_csv_file
31+
file: ${input_file_path}
32+
label_column: hypothesis
33+
source_id_column:
34+
- weapon1
35+
- weapon2
36+
extra_fields:
37+
- name: split
38+
columns:
39+
- split1
40+
- split2
41+
- split3
42+
cell_type: str
3243

3344
splits:
3445
strategy: lrmodule.input_data.PredefinedCrossValidation

lrmodule/models/breech_face_impression/validation.yaml

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@ lrsystem: &lrsystem
55
architecture: lrmodule.binary_lrsystem
66
modules:
77
steps:
8-
select:
9-
method: lrmodule.helpers.select_marktype_cmc
108
mcmc:
119
method: mcmc
1210
# parameters copied from MCMC test
@@ -37,8 +35,21 @@ experiments:
3735
lr_system: *lrsystem
3836
data:
3937
provider:
40-
method: lrmodule.input_data.ScratchCsvReader
41-
input_file_path: ${input_file_path}
38+
method: parse_features_from_csv_file
39+
file: ${input_file_path}
40+
label_column: hypothesis
41+
source_id_column:
42+
- weapon1
43+
- weapon2
44+
extra_fields:
45+
- name: split
46+
columns:
47+
- split1
48+
- split2
49+
- split3
50+
cell_type: str
51+
ignore_columns:
52+
- accf
4253
splits:
4354
strategy: lrmodule.input_data.PredefinedCrossValidation
4455
output:

lrmodule/models/firing_pin_impression/validation.yaml

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@ lrsystem: &lrsystem
55
architecture: lrmodule.binary_lrsystem
66
modules:
77
steps:
8-
select:
9-
method: lrmodule.helpers.select_marktype_cmc
108
mcmc:
119
method: mcmc
1210
# parameters copied from MCMC test
@@ -37,8 +35,21 @@ experiments:
3735
lr_system: *lrsystem
3836
data:
3937
provider:
40-
method: lrmodule.input_data.ScratchCsvReader
41-
input_file_path: ${input_file_path}
38+
method: parse_features_from_csv_file
39+
file: ${input_file_path}
40+
label_column: hypothesis
41+
source_id_column:
42+
- weapon1
43+
- weapon2
44+
extra_fields:
45+
- name: split
46+
columns:
47+
- split1
48+
- split2
49+
- split3
50+
cell_type: str
51+
ignore_columns:
52+
- accf
4253
splits:
4354
strategy: lrmodule.input_data.PredefinedCrossValidation
4455
output:

tests/test_input_data.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,22 @@
22

33
import numpy as np
44
from lir.data.models import FeatureData
5+
from lir.datasets.feature_data_csv import ExtraField, FeatureDataCsvFileParser
56
from numpy import array
67

7-
from lrmodule.input_data import ScratchCsvReader, PredefinedCrossValidation
8+
from lrmodule.input_data import PredefinedCrossValidation
89

910

1011
def test_input_data_to_instances():
1112
"""Check that input data is correctly parsed to instances (having multiple folds)."""
1213
# Arrange
1314
input_file = Path(__file__).parent / "fixtures/input_data/train_test_data.csv"
14-
dataset = ScratchCsvReader(input_file).get_instances()
15+
dataset = FeatureDataCsvFileParser(
16+
file=input_file,
17+
label_column="hypothesis",
18+
source_id_column=["weapon1", "weapon2"],
19+
extra_fields=[ExtraField("split", ["split1", "split2", "split3"], str)],
20+
).get_instances()
1521
strategy = PredefinedCrossValidation()
1622

1723
# The following train/test splits for the given data_subsets are expected
@@ -34,15 +40,18 @@ def test_input_data_to_instances():
3440
]
3541

3642
# Act
37-
assert dataset.split.shape == (5, 3), "role assignment shape should match the input data"
38-
assert np.all(dataset.split[:, 0] == np.array(['t', 'v', 'v', 'n', 't']))
43+
split = getattr(dataset, "split")
44+
assert split.shape == (5, 3), "role assignment shape should match the input data"
45+
assert np.all(split[:, 0] == np.array(["t", "v", "v", "n", "t"]))
3946

4047
data_subsets = list(strategy.apply(dataset))
4148

4249
# Assert
4350
# The fixture contains 3 subsets of data (3-fold cross validation)
4451
assert len(data_subsets) == 3 # noqa: PLR2004 (magic number)
4552

46-
for i, ((actual_train, actual_test), (expected_train, expected_test)) in enumerate(zip(data_subsets, [subset_1, subset_2, subset_3])):
53+
for i, ((actual_train, actual_test), (expected_train, expected_test)) in enumerate(
54+
zip(data_subsets, [subset_1, subset_2, subset_3])
55+
):
4756
assert FeatureData(features=actual_train.features, labels=actual_train.labels) == expected_train
4857
assert FeatureData(features=actual_test.features, labels=actual_test.labels) == expected_test

0 commit comments

Comments
 (0)