diff --git a/docs/contributing/index.rst b/docs/contributing/index.rst
index 600d05a02..5c0e88455 100644
--- a/docs/contributing/index.rst
+++ b/docs/contributing/index.rst
@@ -210,7 +210,7 @@ Also test the functions available in the ``Dataset`` class.
 
 .. code:: python
 
-   > data.get_node_columns(['sources'])
+   > data.get_node_columns(['sources'], scope='SPRAS')
        sources
    NODEID
    0 True A
diff --git a/spras/allpairs.py b/spras/allpairs.py
index 7ff1ade5c..a5465c359 100644
--- a/spras/allpairs.py
+++ b/spras/allpairs.py
@@ -29,9 +29,7 @@ def generate_inputs(data: Dataset, filename_map):
 
     # Get sources and targets for node input file
     # Borrowed code from pathlinker.py
-    sources_targets = data.get_node_columns(["sources", "targets"])
-    if sources_targets is None:
-        raise ValueError("All Pairs Shortest Paths requires sources and targets")
+    sources_targets = data.get_node_columns(["sources", "targets"], "All Pairs Shortest Paths")
 
     both_series = sources_targets.sources & sources_targets.targets
     for _index, row in sources_targets[both_series].iterrows():
diff --git a/spras/btb.py b/spras/btb.py
index 71f774858..ed6491c4d 100644
--- a/spras/btb.py
+++ b/spras/btb.py
@@ -40,19 +40,11 @@ def generate_inputs(data, filename_map):
 
     # Get sources and write to file, repeat for targets
     # Does not check whether a node is a source and a target
-    for node_type in ['sources', 'targets']:
-        nodes = data.get_node_columns([node_type])
-        if nodes is None:
-            raise ValueError(f'No {node_type} found in the node files')
-
+    for node_type, nodes in data.get_node_columns_separate(['sources', 'targets'], "BowTieBuilder").items():
         # TODO test whether this selection is needed, what values could the column contain that we would want to
         # include or exclude?
         nodes = nodes.loc[nodes[node_type]]
-        if node_type == "sources":
-            nodes.to_csv(filename_map["sources"], sep= '\t', index=False, columns=['NODEID'], header=False)
-        elif node_type == "targets":
-            nodes.to_csv(filename_map["targets"], sep= '\t', index=False, columns=['NODEID'], header=False)
-
+        nodes.to_csv(filename_map[node_type], sep='\t', index=False, columns=['NODEID'], header=False)
 
     # Create network file
     edges = data.get_interactome()
diff --git a/spras/dataset.py b/spras/dataset.py
index 1346750e3..6e9db500a 100644
--- a/spras/dataset.py
+++ b/spras/dataset.py
@@ -10,6 +10,52 @@ Methods and intermediate state for loading data and putting it into pandas
 tables for use by pathway reconstruction algorithms.
 """
 
+class MissingDataError(RuntimeError):
+    """
+    Raised when required data is missing from the input dataframe, typically during `generate_inputs`.
+    This is thrown by PRMs.
+    """
+
+    scope: str
+    """
+    This is usually the name of the PRM throwing this error.
+    We generically call this the 'scope'.
+    """
+
+    missing_message: list[str] | str
+    """
+    Either a list of the specific columns that are missing, or a custom
+    error message.
+
+    This is rendered in the format:
+
+    (If a string) {scope} is missing data: {message}
+    (If a list) {scope} requires columns: {message joined by ", "}
+    """
+
+    @staticmethod
+    def process_message(scope: str, missing_message: list[str] | str) -> str:
+        if isinstance(missing_message, str):
+            return f"{scope} is missing data: {missing_message}"
+        else:
+            return "{} requires columns: {}".format(scope, ", ".join(missing_message))
+
+    def __init__(self, scope: str, missing_message: list[str] | str):
+        """
+        Constructs a new MissingDataError.
+
+        @param scope: The name of the PRM (or a more general scope like 'SPRAS') reporting the error.
+        @param missing_message: The message or missing columns to let the user know about.
+            See the `MissingDataError#missing_message` docstring for more info.
+        """
+
+        self.scope = scope
+        self.missing_message = missing_message
+
+        super().__init__(MissingDataError.process_message(scope, missing_message))
+
+    def __str__(self):
+        return MissingDataError.process_message(self.scope, self.missing_message)
 
 
 class Dataset:
@@ -132,14 +178,23 @@ def load_files_from_dict(self, dataset_dict):
         self.node_table.insert(0, "NODEID", self.node_table.pop("NODEID"))
         self.other_files = dataset_dict["other_files"]
 
-    def get_node_columns(self, col_names: list[str]) -> pd.DataFrame:
+    def get_node_columns(self, col_names: list[str], scope: str) -> pd.DataFrame:
         """
-        returns: A table containing the requested column names and node IDs
+        @param scope: The name of the algorithm (or a more general 'scope' like SPRAS)
+            to include in the error message if requested columns are missing.
+        @returns: A table containing the requested column names and node IDs
             for all nodes with at least 1 of the requested values being non-empty
         """
+        # Don't mutate the input col_names
+        col_names = col_names.copy()
+
         if self.node_table is None:
             raise ValueError("node_table is None: can't request node columns of an empty dataset.")
 
+        needed_columns = set(col_names).difference(self.node_table.columns)
+        if len(needed_columns) != 0:
+            raise MissingDataError(scope, sorted(needed_columns))
+
         col_names.append(self.NODE_ID)
         filtered_table = self.node_table[col_names]
         filtered_table = filtered_table.dropna(
@@ -156,6 +211,27 @@
             )
         return filtered_table
 
+    def get_node_columns_separate(self, col_names: list[str], scope: str) -> dict[str, pd.DataFrame]:
+        """
+        Get each column in `col_names` through a separate call to `get_node_columns`,
+        so each returned table only keeps the NODEIDs that have a value for that column.
+
+        This is useful for writing separate node lists for specific column names.
+        """
+        if self.node_table is None:
+            raise ValueError("node_table is None: can't request node columns of an empty dataset.")
+
+        # Check all requested columns up front so the error reports every missing column at once
+        needed_columns = set(col_names).difference(self.node_table.columns)
+        if len(needed_columns) != 0:
+            raise MissingDataError(scope, sorted(needed_columns))
+
+        result_dict: dict[str, pd.DataFrame] = dict()
+        for name in col_names:
+            result_dict[name] = self.get_node_columns([name], scope)
+
+        return result_dict
+
     def contains_node_columns(self, col_names: list[str] | str):
         if self.node_table is None:
             raise ValueError("node_table is None: can't request node columns of an empty dataset.")
diff --git a/spras/domino.py b/spras/domino.py
index 28a391434..abbb3f005 100644
--- a/spras/domino.py
+++ b/spras/domino.py
@@ -41,11 +41,7 @@ def generate_inputs(data, filename_map):
     DOMINO.validate_required_inputs(filename_map)
 
     # Get active genes for node input file
-    if data.contains_node_columns('active'):
-        # NODEID is always included in the node table
-        node_df = data.get_node_columns(['active'])
-    else:
-        raise ValueError('DOMINO requires active genes')
+    node_df = data.get_node_columns(['active'], 'DOMINO')
 
     node_df = node_df[node_df['active'] == True]
 
     # Transform each node id with a prefix
diff --git a/spras/meo.py b/spras/meo.py
index 3e4ca4d46..80d0fbc18 100644
--- a/spras/meo.py
+++ b/spras/meo.py
@@ -98,11 +98,7 @@ def generate_inputs(data, filename_map):
 
     # Get sources and write to file, repeat for targets
     # Does not check whether a node is a source and a target
-    for node_type in ['sources', 'targets']:
-        nodes = data.get_node_columns([node_type])
-        if nodes is None:
-            raise ValueError(f'No {node_type} found in the node files')
-
+    for node_type, nodes in data.get_node_columns_separate(['sources', 'targets'], "MEO").items():
         # TODO test whether this selection is needed, what values could the column contain that we would want to
         # include or exclude?
         nodes = nodes.loc[nodes[node_type]]
diff --git a/spras/mincostflow.py b/spras/mincostflow.py
index f883afb52..0b8bbde7c 100644
--- a/spras/mincostflow.py
+++ b/spras/mincostflow.py
@@ -40,10 +40,7 @@ def generate_inputs(data, filename_map):
     MinCostFlow.validate_required_inputs(filename_map)
 
     # will take the sources and write them to files, and repeats with targets
-    for node_type in ['sources', 'targets']:
-        nodes = data.get_node_columns([node_type])
-        if nodes is None:
-            raise ValueError(f'No {node_type} found in the node files')
+    for node_type, nodes in data.get_node_columns_separate(['sources', 'targets'], "MinCostFlow").items():
         # take nodes one column data frame, call sources/ target series
         nodes = nodes.loc[nodes[node_type]]
         # creates with the node type without headers
diff --git a/spras/omicsintegrator1.py b/spras/omicsintegrator1.py
index a92d7ecea..e57d3e660 100644
--- a/spras/omicsintegrator1.py
+++ b/spras/omicsintegrator1.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 
 from spras.containers import prepare_volume, run_container_and_log
+from spras.dataset import MissingDataError
 from spras.interactome import reinsert_direction_col_mixed
 from spras.prm import PRM
 from spras.util import add_rank_column, duplicate_edges, raw_pathway_df
@@ -64,14 +65,14 @@ def generate_inputs(data, filename_map):
 
     if data.contains_node_columns('prize'):
         # NODEID is always included in the node table
-        node_df = data.get_node_columns(['prize'])
+        node_df = data.get_node_columns(['prize'], 'Omics Integrator 1')
     elif data.contains_node_columns(['sources', 'targets']):
         # If there aren't prizes but are sources and targets, make prizes based on them
-        node_df = data.get_node_columns(['sources','targets'])
+        node_df = data.get_node_columns(['sources', 'targets'], 'Omics Integrator 1')
         node_df.loc[node_df['sources']==True, 'prize'] = 1.0
         node_df.loc[node_df['targets']==True, 'prize'] = 1.0
     else:
-        raise ValueError("Omics Integrator 1 requires node prizes or sources and targets")
+        raise MissingDataError("Omics Integrator 1", "(node prizes) or (sources and targets)")
 
     # Omics Integrator already gives warnings for strange prize values, so we won't here
     node_df.to_csv(filename_map['prizes'],sep='\t',index=False,columns=['NODEID','prize'],header=['name','prize'])
diff --git a/spras/omicsintegrator2.py b/spras/omicsintegrator2.py
index 8b97fa2d1..ac473ba32 100644
--- a/spras/omicsintegrator2.py
+++ b/spras/omicsintegrator2.py
@@ -3,7 +3,7 @@
 import pandas as pd
 
 from spras.containers import prepare_volume, run_container_and_log
-from spras.dataset import Dataset
+from spras.dataset import Dataset, MissingDataError
 from spras.interactome import reinsert_direction_col_undirected
 from spras.prm import PRM
 from spras.util import add_rank_column, duplicate_edges
@@ -36,14 +36,14 @@ def generate_inputs(data: Dataset, filename_map):
 
     if data.contains_node_columns('prize'):
         # NODEID is always included in the node table
-        node_df = data.get_node_columns(['prize'])
+        node_df = data.get_node_columns(['prize'], 'Omics Integrator 2')
     elif data.contains_node_columns(['sources', 'targets']):
         # If there aren't prizes but are sources and targets, make prizes based on them
-        node_df = data.get_node_columns(['sources', 'targets'])
+        node_df = data.get_node_columns(['sources', 'targets'], 'Omics Integrator 2')
         node_df.loc[node_df['sources']==True, 'prize'] = 1.0
         node_df.loc[node_df['targets']==True, 'prize'] = 1.0
     else:
-        raise ValueError("Omics Integrator 2 requires node prizes or sources and targets")
+        raise MissingDataError("Omics Integrator 2", "(node prizes) or (sources and targets)")
 
     # Omics Integrator already gives warnings for strange prize values, so we won't here
     node_df.to_csv(filename_map['prizes'], sep='\t', index=False, columns=['NODEID', 'prize'], header=['name','prize'])
diff --git a/spras/pathlinker.py b/spras/pathlinker.py
index c7cabc97b..62fba9742 100644
--- a/spras/pathlinker.py
+++ b/spras/pathlinker.py
@@ -37,9 +37,7 @@ def generate_inputs(data, filename_map):
     PathLinker.validate_required_inputs(filename_map)
 
     # Get sources and targets for node input file
-    sources_targets = data.get_node_columns(["sources", "targets"])
-    if sources_targets is None:
-        return False
+    sources_targets = data.get_node_columns(["sources", "targets"], 'PathLinker')
 
     both_series = sources_targets.sources & sources_targets.targets
     for _index, row in sources_targets[both_series].iterrows():
         warn_msg = row.NODEID + " has been labeled as both a source and a target."
diff --git a/spras/responsenet.py b/spras/responsenet.py
index bbc3b5255..029dd9e71 100644
--- a/spras/responsenet.py
+++ b/spras/responsenet.py
@@ -34,10 +34,7 @@ def generate_inputs(data, filename_map):
     ResponseNet.validate_required_inputs(filename_map)
 
     # will take the sources and write them to files, and repeats with targets
-    for node_type in ['sources', 'targets']:
-        nodes = data.get_node_columns([node_type])
-        if nodes is None:
-            raise ValueError(f'No {node_type} found in the node files')
+    for node_type, nodes in data.get_node_columns_separate(['sources', 'targets'], "ResponseNet").items():
         # take nodes one column data frame, call sources/ target series
         nodes = nodes.loc[nodes[node_type]]
         # creates with the node type without headers
diff --git a/spras/rwr.py b/spras/rwr.py
index 4717aa064..12f71dbe1 100644
--- a/spras/rwr.py
+++ b/spras/rwr.py
@@ -19,13 +19,9 @@ def generate_inputs(data, filename_map):
     RWR.validate_required_inputs(filename_map)
 
     # Get sources and targets for node input file
-    if data.contains_node_columns(["sources","targets"]):
-        sources = data.get_node_columns(["sources"])
-        targets = data.get_node_columns(["targets"])
-        nodes = pd.DataFrame({'NODEID':sources['NODEID'].tolist() + targets['NODEID'].tolist()})
-        nodes.to_csv(filename_map['nodes'],sep='\t',index=False,columns=['NODEID'],header=False)
-    else:
-        raise ValueError("Invalid node data")
+    sources_targets = data.get_node_columns_separate(["sources", "targets"], "RWR")
+    nodes = pd.DataFrame({'NODEID': sources_targets["sources"]['NODEID'].tolist() + sources_targets["targets"]['NODEID'].tolist()})
+    nodes.to_csv(filename_map['nodes'],sep='\t',index=False,columns=['NODEID'],header=False)
 
     # Get edge data for network file
     edges = data.get_interactome()
diff --git a/spras/strwr.py b/spras/strwr.py
index 65ea9f923..78da75cc4 100644
--- a/spras/strwr.py
+++ b/spras/strwr.py
@@ -18,14 +18,8 @@ def generate_inputs(data, filename_map):
     ST_RWR.validate_required_inputs(filename_map)
 
     # Get separate source and target nodes for source and target files
-    if data.contains_node_columns(["sources","targets"]):
-        sources = data.get_node_columns(["sources"])
-        sources.to_csv(filename_map['sources'],sep='\t',index=False,columns=['NODEID'],header=False)
-
-        targets = data.get_node_columns(["targets"])
-        targets.to_csv(filename_map['targets'],sep='\t',index=False,columns=['NODEID'],header=False)
-    else:
-        raise ValueError("Invalid node data")
+    for node_type, nodes in data.get_node_columns_separate(["sources", "targets"], "Source-Target RWR").items():
+        nodes.to_csv(filename_map[node_type],sep='\t',index=False,columns=['NODEID'],header=False)
 
     # Get edge data for network file
     edges = data.get_interactome()
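
As a quick sanity check on the new API, here is a minimal sketch (not part of the diff) of the two message forms `MissingDataError` produces and of how a missing column surfaces through `get_node_columns_separate`. The `Dataset.__new__` shortcut below is illustration-only scaffolding; real datasets are constructed through `load_files_from_dict`.

```python
import pandas as pd

from spras.dataset import Dataset, MissingDataError

# List form: "<scope> requires columns: <col1>, <col2>"
print(MissingDataError("PathLinker", ["sources", "targets"]))

# String form: "<scope> is missing data: <message>"
print(MissingDataError("Omics Integrator 1", "(node prizes) or (sources and targets)"))

# Hypothetical Dataset whose node table has 'sources' but no 'targets' column
data = Dataset.__new__(Dataset)  # illustration only; skips normal construction
data.node_table = pd.DataFrame({"NODEID": ["A", "B"], "sources": [True, False]})

try:
    data.get_node_columns_separate(["sources", "targets"], "MEO")
except MissingDataError as err:
    print(err)  # MEO requires columns: targets
```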