Skip to content

Commit eedf304

Browse files
committed
update to lir 1.3.3
1 parent bea151b commit eedf304

9 files changed

Lines changed: 470 additions & 493 deletions

File tree

lrmodule/input_data.py

Lines changed: 36 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from collections import defaultdict
2-
from collections.abc import Iterator
1+
import logging
2+
from collections.abc import Iterable
33
from enum import StrEnum
44
from functools import cache
55
from pathlib import Path
66

7-
import pandas as pd
8-
from lir.data.data_strategies import DataStrategy
9-
from lir.data.models import FeatureData
7+
from lir.data.io import search_path
8+
from lir.data.models import DataStrategy, FeatureData
9+
from lir.datasets.feature_data_csv import ExtraField, FeatureDataCsvParser
10+
11+
LOG = logging.getLogger(__name__)
1012

1113

1214
class TestTrainSplit(StrEnum):
@@ -15,16 +17,29 @@ class TestTrainSplit(StrEnum):
1517
TEST = "v"
1618

1719

18-
class ScratchData(DataStrategy):
19-
def __init__(self, input_file_path: Path):
20+
SPLIT_COLUMNS = ["split1", "split2", "split3"]
21+
22+
23+
class ScratchCsvReader(FeatureDataCsvParser):
24+
def __init__(self, input_file_path: Path | str):
2025
"""Read and represent Scratch specific input data, as corresponding instances.
2126
2227
The data might include n-fold cross validation splits, where each fold has a train/test split.
2328
This class provides access to iterate over the available folds and the corresponding train and test splits.
2429
"""
25-
self.file_path = input_file_path
30+
super().__init__(
31+
source_id_column=["weapon1", "weapon2"],
32+
label_column="hypothesis",
33+
extra_fields=[
34+
ExtraField("split", SPLIT_COLUMNS, str),
35+
],
36+
message_prefix=f"{input_file_path}: ",
37+
)
38+
39+
self.file_path = Path(input_file_path)
2640

27-
def _read_instances_from_file(self):
41+
@cache
42+
def get_instances(self) -> FeatureData:
2843
"""Read K-fold cross validation CSV input data to a list of K corresponding subsets of test/train folds.
2944
3045
In the CSV file, subsets of the data are indicated by the "split<N>" column. For example, 3-fold cross
@@ -36,58 +51,18 @@ def _read_instances_from_file(self):
3651
The remaining columns are treated as features. This means that the pipeline in which this data is used
3752
should filter out any non-relevant feature columns before training or evaluating a model.
3853
"""
39-
df = pd.read_csv(self.file_path)
40-
41-
# Ensure all expected columns are present
42-
expected_columns = ["weapon1", "weapon2", "hypothesis", "split1"]
43-
if not all(column in df.columns for column in expected_columns):
44-
raise ValueError(
45-
f"Missing one of the expected columns: {', '.join(set(expected_columns) - set(df.columns))}"
46-
)
47-
48-
# Find all columns regarding the prepared folds, named 'split*' ('split1', 'split2', etc.)
49-
fold_column_names = [c for c in df.columns if c.startswith("split")]
54+
path = search_path(self.file_path)
55+
LOG.debug(f"parsing CSV file: {self.file_path} as {path}")
56+
with open(path) as f:
57+
return self._parse_file(f)
5058

51-
# Feature columns are all columns that are not the expected columns
52-
feature_columns = [c for c in df.columns if c not in expected_columns and c not in fold_column_names]
5359

54-
label_column = ["hypothesis"]
55-
56-
# Group the folds by the column name, i.e. 'split1', 'split2', etc.
57-
df_with_subsets = df.melt(
58-
id_vars=label_column + feature_columns,
59-
value_vars=fold_column_names,
60-
var_name="subset",
61-
value_name="test_train_split",
62-
)
63-
64-
subsets = []
65-
66-
# Loop over each subset
67-
for _, folds in df_with_subsets.groupby("subset"):
68-
# Filter out the data marked as "not used"
69-
test_train_folds = folds[folds.test_train_split != TestTrainSplit.NOT_USED]
70-
71-
# Loop over 'train' / 'test' folds for the current subset
72-
subset_folds = defaultdict()
73-
74-
for test_or_train_indicator, raw_data in test_train_folds.groupby("test_train_split"):
75-
# The `test_or_train_indicator` refers to the role of this data
76-
# in the current fold; belonging to either the 'test' or 'train' split.
77-
features = raw_data[feature_columns].to_numpy(dtype=float).reshape(-1, len(feature_columns))
78-
labels = raw_data[label_column].to_numpy(dtype=int).flatten()
79-
80-
subset_folds[test_or_train_indicator] = FeatureData(features=features, labels=labels)
81-
82-
subsets.append((subset_folds[TestTrainSplit.TRAIN], subset_folds[TestTrainSplit.TEST]))
83-
84-
return subsets
85-
86-
@cache
87-
def _get_instances(self):
88-
"""Read instances from file only once."""
89-
return self._read_instances_from_file()
60+
class PredefinedCrossValidation(DataStrategy):
61+
"""Return a series of train/test sets for a predefined cross-validation setup."""
9062

91-
def __iter__(self) -> Iterator[tuple[FeatureData, FeatureData]]:
92-
"""Allow iteration by looping over the resulting train/test split(s)."""
93-
yield from self._get_instances()
63+
def apply(self, instances: FeatureData) -> Iterable[tuple[FeatureData, FeatureData]]:
64+
"""Return a series of train/test sets for a predefined cross-validation setup."""
65+
for split in range(len(SPLIT_COLUMNS)):
66+
training_data = instances[instances.split[:, split] == TestTrainSplit.TRAIN.value] # type: ignore
67+
test_data = instances[instances.split[:, split] == TestTrainSplit.TEST.value] # type: ignore
68+
yield training_data, test_data

lrmodule/mcmc.py

Lines changed: 0 additions & 203 deletions
This file was deleted.

0 commit comments

Comments (0)