11import logging
22from collections .abc import Iterable
33from enum import StrEnum
4- from functools import cache
5- from pathlib import Path
64
7- from lir .data .io import search_path
85from lir .data .models import DataStrategy , FeatureData
9- from lir .datasets .feature_data_csv import ExtraField , FeatureDataCsvParser
106
117LOG = logging .getLogger (__name__ )
128
@@ -17,52 +13,13 @@ class TestTrainSplit(StrEnum):
1713 TEST = "v"
1814
1915
20- SPLIT_COLUMNS = ["split1" , "split2" , "split3" ]
21-
22-
23- class ScratchCsvReader (FeatureDataCsvParser ):
24- def __init__ (self , input_file_path : Path | str ):
25- """Read and represent Scratch specific input data, as corresponding instances.
26-
27- The data might include n-fold cross validation splits, where each fold has a train/test split.
28- This class provides access to iterate over the available folds and the corresponding train and test splits.
29- """
30- super ().__init__ (
31- source_id_column = ["weapon1" , "weapon2" ],
32- label_column = "hypothesis" ,
33- extra_fields = [
34- ExtraField ("split" , SPLIT_COLUMNS , str ),
35- ],
36- message_prefix = f"{ input_file_path } : " ,
37- )
38-
39- self .file_path = Path (input_file_path )
40-
41- @cache
42- def get_instances (self ) -> FeatureData :
43- """Read K-fold cross validation CSV input data to a list of K corresponding subsets of test/train folds.
44-
45- In the CSV file, subsets of the data are indicated by the "split<N>" column. For example, 3-fold cross
46- validation is represented through columns 'split1', 'split2' and 'split3' which indicate if the data in
47- this subset belongs to the test split ("t"), train split ("v") or is not used ("n").
48-
49- The FeatureData instances are created from all columns that are not part of the expected columns:
50- 'weapon1', 'weapon2', 'hypothesis', 'split<N>' (for all N). The 'hypothesis' column is used as label.
51- The remaining columns are treated as features. This means that the pipeline in which this data is used
52- should filter out any non-relevant feature columns before training or evaluating a model.
53- """
54- path = search_path (self .file_path )
55- LOG .debug (f"parsing CSV file: { self .file_path } as { path } " )
56- with open (path ) as f :
57- return self ._parse_file (f )
58-
59-
class PredefinedCrossValidation(DataStrategy):
    """Return a series of train/test sets for a predefined cross-validation setup."""

    def apply(self, instances: FeatureData) -> Iterable[tuple[FeatureData, FeatureData]]:
        """Yield one (train, test) pair per predefined split column.

        The role matrix is read from ``instances.split``: column ``k`` assigns
        every instance a role for fold ``k``, using the marker values from
        ``TestTrainSplit``. Rows whose marker matches neither the train nor the
        test value are excluded from that fold entirely.
        """
        role_assignments = instances.split  # type: ignore
        fold_count = role_assignments.shape[1]
        for fold in range(fold_count):
            fold_roles = role_assignments[:, fold]
            train_subset = instances[fold_roles == TestTrainSplit.TRAIN.value]
            test_subset = instances[fold_roles == TestTrainSplit.TEST.value]
            yield train_subset, test_subset
0 commit comments