1- from collections import defaultdict
2- from collections .abc import Iterator
1+ import logging
2+ from collections .abc import Iterable
33from enum import StrEnum
44from functools import cache
55from pathlib import Path
66
7- import pandas as pd
8- from lir .data .data_strategies import DataStrategy
9- from lir .data .models import FeatureData
7+ from lir .data .io import search_path
8+ from lir .data .models import DataStrategy , FeatureData
9+ from lir .datasets .feature_data_csv import ExtraField , FeatureDataCsvParser
10+
11+ LOG = logging .getLogger (__name__ )
1012
1113
1214class TestTrainSplit (StrEnum ):
@@ -15,16 +17,29 @@ class TestTrainSplit(StrEnum):
1517 TEST = "v"
1618
1719
18- class ScratchData (DataStrategy ):
19- def __init__ (self , input_file_path : Path ):
# Names of the CSV columns that mark each row's fold membership, one column per
# cross-validation fold; the cell values are TestTrainSplit markers.
SPLIT_COLUMNS = ["split1", "split2", "split3"]
21+
22+
class ScratchCsvReader(FeatureDataCsvParser):
    def __init__(self, input_file_path: Path | str):
        """Read and represent Scratch specific input data, as corresponding instances.

        The data might include n-fold cross validation splits, where each fold has a train/test split.
        This class provides access to iterate over the available folds and the corresponding train and test splits.
        """
        super().__init__(
            source_id_column=["weapon1", "weapon2"],
            label_column="hypothesis",
            extra_fields=[
                ExtraField("split", SPLIT_COLUMNS, str),
            ],
            message_prefix=f"{input_file_path}: ",
        )

        self.file_path = Path(input_file_path)
        # Per-instance memo for `get_instances`. NOTE: `functools.cache` on an
        # instance method would key the cache on `self` and keep every reader
        # (and its parsed data) alive for the process lifetime (ruff B019).
        self._instances: FeatureData | None = None

    def get_instances(self) -> FeatureData:
        """Read K-fold cross validation CSV input data to a list of K corresponding subsets of test/train folds.

        In the CSV file, subsets of the data are indicated by the "split<N>" column. For example, 3-fold cross
        validation data has the columns "split1", "split2" and "split3" (see SPLIT_COLUMNS), each marking the
        train/test membership of every row for that fold.

        The remaining columns are treated as features. This means that the pipeline in which this data is used
        should filter out any non-relevant feature columns before training or evaluating a model.

        The file is read and parsed only once; subsequent calls return the memoized result.
        """
        if self._instances is None:
            path = search_path(self.file_path)
            # Lazy %-style args: the message is only formatted when DEBUG is enabled.
            LOG.debug("parsing CSV file: %s as %s", self.file_path, path)
            with open(path) as f:
                self._instances = self._parse_file(f)
        return self._instances
5058
51- # Feature columns are all columns that are not the expected columns
52- feature_columns = [c for c in df .columns if c not in expected_columns and c not in fold_column_names ]
5359
54- label_column = ["hypothesis" ]
55-
56- # Group the folds by the column name, i.e. 'split1', 'split2', etc.
57- df_with_subsets = df .melt (
58- id_vars = label_column + feature_columns ,
59- value_vars = fold_column_names ,
60- var_name = "subset" ,
61- value_name = "test_train_split" ,
62- )
63-
64- subsets = []
65-
66- # Loop over each subset
67- for _ , folds in df_with_subsets .groupby ("subset" ):
68- # Filter out the data marked as "not used"
69- test_train_folds = folds [folds .test_train_split != TestTrainSplit .NOT_USED ]
70-
71- # Loop over 'train' / 'test' folds for the current subset
72- subset_folds = defaultdict ()
73-
74- for test_or_train_indicator , raw_data in test_train_folds .groupby ("test_train_split" ):
75- # The `test_or_train_indicator` refers to the role of this data
76- # in the current fold; belonging to either the 'test' or 'train' split.
77- features = raw_data [feature_columns ].to_numpy (dtype = float ).reshape (- 1 , len (feature_columns ))
78- labels = raw_data [label_column ].to_numpy (dtype = int ).flatten ()
79-
80- subset_folds [test_or_train_indicator ] = FeatureData (features = features , labels = labels )
81-
82- subsets .append ((subset_folds [TestTrainSplit .TRAIN ], subset_folds [TestTrainSplit .TEST ]))
83-
84- return subsets
85-
86- @cache
87- def _get_instances (self ):
88- """Read instances from file only once."""
89- return self ._read_instances_from_file ()
class PredefinedCrossValidation(DataStrategy):
    """Return a series of train/test sets for a predefined cross-validation setup."""

    def apply(self, instances: FeatureData) -> Iterable[tuple[FeatureData, FeatureData]]:
        """Return a series of train/test sets for a predefined cross-validation setup."""
        for fold_index, _column in enumerate(SPLIT_COLUMNS):
            # Markers for this fold: one TestTrainSplit value per instance.
            fold_markers = instances.split[:, fold_index]  # type: ignore
            yield (
                instances[fold_markers == TestTrainSplit.TRAIN.value],  # type: ignore
                instances[fold_markers == TestTrainSplit.TEST.value],  # type: ignore
            )
0 commit comments