diff --git a/Snakefile b/Snakefile
index 02f019e8d..65e24f15d 100644
--- a/Snakefile
+++ b/Snakefile
@@ -295,6 +295,9 @@ rule parse_output:
         params = reconstruction_params(wildcards.algorithm, wildcards.params).copy()
         params['dataset'] = input.dataset_file
         runner.parse_output(wildcards.algorithm, input.raw_file, output.standardized_file, params)
+        # TODO: cache heuristics result, store partial heuristics configuration file
+        # to allow this rule to update when heuristics change
+        _config.config.heuristics.validate_graph_from_file(output.standardized_file)
 
 # TODO: reuse in the future once we make summary work for mixed graphs. See https://github.com/Reed-CompBio/spras/issues/128
 # Collect summary statistics for a single pathway
diff --git a/spras/analysis/summary.py b/spras/analysis/summary.py
index c8abc1cad..432dba0a4 100644
--- a/spras/analysis/summary.py
+++ b/spras/analysis/summary.py
@@ -1,10 +1,11 @@
 from pathlib import Path
-from statistics import median
 from typing import Iterable
 
 import networkx as nx
 import pandas as pd
 
+from spras.statistics import compute_statistics, statistics_options
+
 
 def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, algo_params: dict[str, dict], algo_with_params: list) -> pd.DataFrame:
@@ -47,44 +48,11 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg
         # Save the network name, number of nodes, number edges, and number of connected components
         nw_name = str(file_path)
-        number_nodes = nw.number_of_nodes()
-        number_edges = nw.number_of_edges()
-        ncc = nx.number_connected_components(nw)
-
-        # Save the max/median degree, average clustering coefficient, and density
-        if number_nodes == 0:
-            max_degree = 0
-            median_degree = 0.0
-            density = 0.0
-        else:
-            degrees = [deg for _, deg in nw.degree()]
-            max_degree = max(degrees)
-            median_degree = median(degrees)
-            density = nx.density(nw)
-
-        cc = list(nx.connected_components(nw))
-        # Save the max diameter
-        # Use diameter only for components with ≥2 nodes (singleton components have diameter 0)
-        diameters = [
-            nx.diameter(nw.subgraph(c).copy()) if len(c) > 1 else 0
-            for c in cc
-        ]
-        max_diameter = max(diameters, default=0)
-
-        # Save the average path lengths
-        # Compute average shortest path length only for components with ≥2 nodes (undefined for singletons, set to 0.0)
-        avg_path_lengths = [
-            nx.average_shortest_path_length(nw.subgraph(c).copy()) if len(c) > 1 else 0.0
-            for c in cc
-        ]
-
-        if len(avg_path_lengths) != 0:
-            avg_path_len = sum(avg_path_lengths) / len(avg_path_lengths)
-        else:
-            avg_path_len = 0.0
+
+        graph_statistics = compute_statistics(nw, statistics_options)
 
         # Initialize list to store current network information
-        cur_nw_info = [nw_name, number_nodes, number_edges, ncc, density, max_degree, median_degree, max_diameter, avg_path_len]
+        cur_nw_info = [nw_name, *graph_statistics.values()]
 
         # Iterate through each node property and save the intersection with the current network
         for node_list in nodes_by_col:
@@ -104,7 +72,7 @@ def summarize_networks(file_paths: Iterable[Path], node_table: pd.DataFrame, alg
         nw_info.append(cur_nw_info)
 
     # Prepare column names
-    col_names = ['Name', 'Number of nodes', 'Number of edges', 'Number of connected components', 'Density', 'Max degree', 'Median degree', 'Max diameter', 'Average path length']
+    col_names = ['Name', *statistics_options]
     col_names.extend(nodes_by_col_labs)
     col_names.append('Parameter combination')
diff --git a/spras/config/config.py b/spras/config/config.py
index 25e6f72de..51d3daf3b 100644
--- a/spras/config/config.py
+++ b/spras/config/config.py
@@ -78,6 +78,8 @@ def __init__(self, raw_config: dict[str, Any]):
         self.container_settings = ProcessedContainerSettings.from_container_settings(parsed_raw_config.containers, self.hash_length)
         # The list of algorithms to run in the workflow. Each is a dict with 'name' as an expected key.
         self.algorithms = None
+        # The graph heuristics used to validate parsed algorithm outputs
+        self.heuristics = parsed_raw_config.heuristics
         # A nested dict mapping algorithm names to dicts that map parameter hashes to parameter combinations.
         # Only includes algorithms that are set to be run with 'include: true'.
         self.algorithm_params = None
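A note on the `summary.py` hunk above: `cur_nw_info = [nw_name, *graph_statistics.values()]` only lines up with `col_names = ['Name', *statistics_options]` because `compute_statistics` returns its keys in the requested order. A minimal sketch of that invariant (the toy graph and row are hypothetical):

```python
import networkx as nx

from spras.statistics import compute_statistics, statistics_options

# Toy stand-in for a parsed pathway network.
nw = nx.Graph([("A", "B"), ("B", "C")])
stats = compute_statistics(nw, statistics_options)

# compute_statistics returns keys in the requested order, so the values
# align positionally with the 'statistics_options' column headers.
assert list(stats.keys()) == statistics_options
row = ["network.txt", *stats.values()]
```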
diff --git a/spras/config/heuristics.py b/spras/config/heuristics.py
new file mode 100644
index 000000000..52c4839c6
--- /dev/null
+++ b/spras/config/heuristics.py
@@ -0,0 +1,105 @@
+import os
+
+import networkx as nx
+from pydantic import BaseModel, ConfigDict
+
+from spras.interval import Interval
+from spras.statistics import compute_statistics, statistics_options
+
+__all__ = ['GraphHeuristicsError', 'GraphHeuristics']
+
+class GraphHeuristicsError(RuntimeError):
+    """
+    Represents an error arising from a graph algorithm output
+    not meeting the necessary graph heuristics.
+    """
+    failed_heuristics: list[tuple[str, float | int, list[Interval]]]
+
+    @staticmethod
+    def format_failed_heuristic(heuristic: tuple[str, float | int, list[Interval]]) -> str:
+        name, actual, intervals = heuristic
+        if len(intervals) == 1:
+            interval_string = f"the interval {intervals[0]}"
+        else:
+            formatted_intervals = ", ".join([str(interval) for interval in intervals])
+            interval_string = f"one of the intervals ({formatted_intervals})"
+        return f"{name} was {actual}, expected a value in {interval_string}"
+
+    @staticmethod
+    def to_string(failed_heuristics: list[tuple[str, float | int, list[Interval]]]) -> str:
+        formatted_heuristics = [
+            GraphHeuristicsError.format_failed_heuristic(heuristic) for heuristic in failed_heuristics
+        ]
+        bullet_list = "\n".join([f"- {heuristic}" for heuristic in formatted_heuristics])
+        return f"The following heuristics failed:\n{bullet_list}"
+
+    def __init__(self, failed_heuristics: list[tuple[str, float | int, list[Interval]]]):
+        super().__init__(GraphHeuristicsError.to_string(failed_heuristics))
+        self.failed_heuristics = failed_heuristics
+
+    def __str__(self) -> str:
+        return GraphHeuristicsError.to_string(self.failed_heuristics)
+
+class GraphHeuristics(BaseModel):
+    number_of_nodes: Interval | list[Interval] = []
+    number_of_edges: Interval | list[Interval] = []
+    number_of_connected_components: Interval | list[Interval] = []
+    density: Interval | list[Interval] = []
+
+    max_degree: Interval | list[Interval] = []
+    median_degree: Interval | list[Interval] = []
+    max_diameter: Interval | list[Interval] = []
+    average_path_length: Interval | list[Interval] = []
+
+    model_config = ConfigDict(extra='forbid')
+
+    def validate_graph(self, graph: nx.DiGraph):
+        statistics_dictionary = {
+            'Number of nodes': self.number_of_nodes,
+            'Number of edges': self.number_of_edges,
+            'Number of connected components': self.number_of_connected_components,
+            'Density': self.density,
+            'Max degree': self.max_degree,
+            'Median degree': self.median_degree,
+            'Max diameter': self.max_diameter,
+            'Average path length': self.average_path_length
+        }
+
+        # quick assert: is statistics_dictionary exhaustive?
+        assert set(statistics_dictionary.keys()) == set(statistics_options)
+
+        # Only compute the statistics that have at least one configured interval
+        stats = compute_statistics(
+            graph,
+            [k for k, v in statistics_dictionary.items() if not isinstance(v, list) or len(v) != 0]
+        )
+
+        failed_heuristics: list[tuple[str, float | int, list[Interval]]] = []
+        for key, value in stats.items():
+            intervals = statistics_dictionary[key]
+            if not isinstance(intervals, list):
+                intervals = [intervals]
+
+            # A statistic passes if it lies in at least one of its configured intervals
+            if not any(interval.mem(value) for interval in intervals):
+                failed_heuristics.append((key, value, intervals))
+
+        if len(failed_heuristics) != 0:
+            raise GraphHeuristicsError(failed_heuristics)
+
+    def validate_graph_from_file(self, path: str | os.PathLike):
+        """
+        Takes in a graph produced by PRM#parse_output,
+        and throws a GraphHeuristicsError if it fails the heuristics in `self`.
+        """
+        # TODO: re-use from summary.py once we have a mixed/hypergraph library
+        G: nx.DiGraph = nx.read_edgelist(path, data=(('Rank', str), ('Direction', str)), create_using=nx.DiGraph)
+
+        # Materialize the edge list so that adding reverse edges below
+        # does not mutate the edge view we are iterating over.
+        for source, target, data in list(G.edges(data=True)):
+            if data["Direction"] == 'U':
+                G.add_edge(target, source, **data)
+
+        return self.validate_graph(G)
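For reviewers, a sketch of what a raised `GraphHeuristicsError` looks like; the file path is hypothetical and the rendered message is approximate:

```python
from spras.config.heuristics import GraphHeuristics, GraphHeuristicsError

heuristics = GraphHeuristics.model_validate({'number_of_nodes': '0 <'})
try:
    heuristics.validate_graph_from_file('output/pathway.txt')  # hypothetical path
except GraphHeuristicsError as err:
    # err.failed_heuristics is a list of (name, actual value, intervals) tuples,
    # and str(err) renders roughly as:
    #   The following heuristics failed:
    #   - Number of nodes was 0, expected a value in the interval 0.0 <
    print(err)
```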
diff --git a/spras/config/schema.py b/spras/config/schema.py
index a1936b0c0..3678a541c 100644
--- a/spras/config/schema.py
+++ b/spras/config/schema.py
@@ -16,6 +16,7 @@
 from pydantic import AfterValidator, BaseModel, ConfigDict
 
 from spras.config.container_schema import ContainerSettings
+from spras.config.heuristics import GraphHeuristics
 from spras.config.util import CaseInsensitiveEnum
 
 # Most options here have an `include` property,
@@ -151,6 +152,8 @@ class RawConfig(BaseModel):
 
     reconstruction_settings: ReconstructionSettings
 
+    heuristics: GraphHeuristics = GraphHeuristics()
+
     # We include use_attribute_docstrings here to preserve the docstrings
     # after attributes at runtime (for future JSON schema generation)
     model_config = ConfigDict(extra='forbid', use_attribute_docstrings=True)
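Since `heuristics` is now an optional top-level key of `RawConfig`, a workflow config can constrain parsed outputs directly. A hypothetical excerpt (field names come from `GraphHeuristics`, interval strings from `Interval.from_string`):

```yaml
heuristics:
  number_of_nodes: "0 <"            # reject empty outputs
  max_diameter: "< 10"              # reject very stretched-out pathways
  number_of_edges: ["< 100", 500]   # passes if the value lies in at least one listed interval
```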
diff --git a/spras/interval.py b/spras/interval.py
new file mode 100644
index 000000000..b65f87a7c
--- /dev/null
+++ b/spras/interval.py
@@ -0,0 +1,234 @@
+"""
+Utilities for defining inequality intervals (e.g. l < x <= u).
+
+For graph heuristics, we allow inequalities with an optional identifier. For example,
+we can say "1500 <" or "1500 < x" for "x greater than 1500", "< 300" for "x less
+than 300", a bare number like "5" for exact equality, or "1000 < x < 2000", etc.
+
+[If there is ever a library that does this, we should replace this code with that library.]
+"""
+
+import tokenize
+from enum import Enum
+from io import BytesIO
+from typing import Any, ClassVar, Optional, Self, cast
+
+from pydantic import model_serializer, model_validator
+from pydantic.dataclasses import dataclass
+
+
+class Operand(Enum):
+    LT = "<"
+    LTE = "<="
+    EQ = "="
+    GTE = ">="
+    GT = ">"
+
+    @classmethod
+    def from_str(cls, string: str) -> Optional[Self]:
+        return next((enum for enum in list(cls) if enum.value == string), None)
+
+    def is_closed(self) -> bool:
+        """Whether this is a closed inequality. We consider = to be closed."""
+        match self:
+            case Operand.LTE: return True
+            case Operand.EQ: return True
+            case Operand.GTE: return True
+        return False
+
+    def as_closed(self):
+        """Closes an operand. EQ does not get modified."""
+        match self:
+            case Operand.LT: return Operand.LTE
+            case Operand.GT: return Operand.GTE
+        return self
+
+    def as_opened(self):
+        """Opens an operand. EQ does not get modified."""
+        match self:
+            case Operand.LTE: return Operand.LT
+            case Operand.GTE: return Operand.GT
+        return self
+
+    def with_closed(self, closed: bool):
+        return self.as_closed() if closed else self.as_opened()
+
+    def compare(self, left, right) -> bool:
+        match self:
+            case Operand.LT: return left < right
+            case Operand.LTE: return left <= right
+            case Operand.EQ: return left == right
+            case Operand.GTE: return left >= right
+            case Operand.GT: return left > right
+
+    def flip(self):
+        """Mirrors an operand: `a < b` holds iff `b > a`."""
+        match self:
+            case Operand.LT: return Operand.GT
+            case Operand.LTE: return Operand.GTE
+            case Operand.EQ: return Operand.EQ
+            case Operand.GTE: return Operand.LTE
+            case Operand.GT: return Operand.LT
+
+    @classmethod
+    def combine(cls, left: Self, right: Self):
+        """Combines two operands, returning None if the operands don't combine well."""
+        match (left, right):
+            case (Operand.LTE, Operand.LTE): return Operand.LTE
+            case (Operand.LT, Operand.LTE): return Operand.LT
+            case (Operand.LTE, Operand.LT): return Operand.LT
+            case (Operand.LT, Operand.LT): return Operand.LT
+            case (Operand.EQ, op): return op
+            case (op, Operand.EQ): return op
+            case (Operand.GTE, Operand.GTE): return Operand.GTE
+            case (Operand.GT, Operand.GTE): return Operand.GT
+            case (Operand.GTE, Operand.GT): return Operand.GT
+            case (Operand.GT, Operand.GT): return Operand.GT
+        return None
+@dataclass
+class Interval:
+    EMPTY_STRING: ClassVar[str] = "{empty interval}"
+
+    lower: Optional[float]
+    upper: Optional[float]
+    lower_closed: bool
+    upper_closed: bool
+
+    def mem(self, num: float) -> bool:
+        """Whether `num` is a member of this interval."""
+        if self.lower is not None:
+            meets_lower = self.lower <= num if self.lower_closed else self.lower < num
+        else:
+            meets_lower = True
+
+        if self.upper is not None:
+            meets_upper = num <= self.upper if self.upper_closed else num < self.upper
+        else:
+            meets_upper = True
+
+        return meets_lower and meets_upper
+
+    @classmethod
+    def single(cls, num: float) -> Self:
+        return cls(lower=num, upper=num, lower_closed=True, upper_closed=True)
+
+    @classmethod
+    def left_operand(cls, operand: Operand, num: float) -> Self:
+        """Creates an interval whose operand is on the left (e.g. <300)"""
+        match operand:
+            case Operand.LT: return cls(lower=None, upper=num, lower_closed=False, upper_closed=False)
+            case Operand.LTE: return cls(lower=None, upper=num, lower_closed=False, upper_closed=True)
+            case Operand.EQ: return cls.single(num)
+            case Operand.GTE: return cls(lower=num, upper=None, lower_closed=True, upper_closed=False)
+            case Operand.GT: return cls(lower=num, upper=None, lower_closed=False, upper_closed=False)
+
+    @classmethod
+    def right_operand(cls, num: float, operand: Operand) -> Self:
+        """Creates an interval whose operand is on the right (e.g. 300<)"""
+        # TODO: remove cast?
+        return cast(Self, Interval.left_operand(operand.flip(), num))
+
+    @classmethod
+    def from_string(cls, input: str) -> Self:
+        # We can't do a normal string#split here for cases like "1500<"
+        tokens = [t.string for t in tokenize.tokenize(BytesIO(input.encode('utf-8')).readline) if t.string != ""]
+        tokens.pop(0)  # drop the utf-8 encoding indicator
+
+        assert len(tokens) != 0
+
+        def parse_num(numstr: str) -> Optional[float]:
+            # Allow pythonic separators. Floats are permitted since
+            # Interval bounds are floats (e.g. for the Density statistic).
+            try:
+                return float(numstr.replace("_", ""))
+            except ValueError:
+                return None
+
+        def is_id(idstr: str) -> bool:
+            return idstr.isidentifier()
+
+        # Case 1: (id?) operand number
+        if is_id(tokens[0]):
+            # No other cases have an id at the beginning: we get rid of it.
+            tokens.pop(0)
+
+        operand = Operand.from_str(tokens[0])
+        if operand is not None:
+            # (cont.) Case 1: (id?) operand number
+            number = parse_num(tokens[1])
+            assert number is not None, f"found operand {operand.value} and expected a number, but found {tokens[1]} instead."
+            return cls.left_operand(operand, number)
+
+        # All other cases start with a number
+        number = parse_num(tokens.pop(0))
+        assert number is not None, f"expected a number, got {input} instead"
+
+        # Case 2: number
+        if len(tokens) == 0:
+            return cls.single(number)
+
+        # All other cases have an operand
+        operand = Operand.from_str(tokens.pop(0))
+        assert operand is not None, f"got {number}, expected an operand afterward."
+
+        # Case 3: number operand (id?)
+        if len(tokens) == 0 or len(tokens) == 1:
+            if len(tokens) == 1:
+                assert is_id(tokens[0])
+            return cls.right_operand(number, operand)
+
+        # Case 4: number operand id operand number
+        id = tokens.pop(0)
+        assert is_id(id), f"got an inequality of the form {number} {operand.value} and expected nothing or an identifier, but got {id} instead."
+
+        second_operand_str = tokens.pop(0)
+        second_operand = Operand.from_str(second_operand_str)
+        assert second_operand is not None, f"got an inequality of the form {number} {operand.value} {id} and was expecting an operand, but got {second_operand_str} instead."
+
+        second_number_str = tokens.pop(0)
+        second_number = parse_num(second_number_str)
+        assert second_number is not None, f"got an inequality of the form {number} {operand.value} {id} {second_operand.value} and was expecting a number, but got {second_number_str} instead."
+
+        # We don't want equals operands in a double inequality (a < b < c)
+        assert operand is not Operand.EQ and second_operand is not Operand.EQ, "in a double inequality, neither operand can be '='!"
+
+        # Are our two numbers valid?
+        combined_operand = Operand.combine(operand, second_operand)
+        assert combined_operand is not None, f"operands {operand.value} and {second_operand.value} must combine well with each other!"
+        assert combined_operand.compare(number, second_number), f"{number} {combined_operand.value} {second_number} does not hold!"
+
+        if combined_operand.as_opened() == Operand.LT:
+            return cls(
+                lower=number,
+                upper=second_number,
+                lower_closed=operand.is_closed(),
+                upper_closed=second_operand.is_closed()
+            )
+        else:
+            return cls(
+                lower=second_number,
+                upper=number,
+                lower_closed=second_operand.is_closed(),
+                upper_closed=operand.is_closed()
+            )
+
+    def __str__(self) -> str:
+        if self.lower is None and self.upper is None:
+            return Interval.EMPTY_STRING
+        if self.lower is None:
+            return Operand.LT.with_closed(self.upper_closed).value + " " + str(self.upper)
+        if self.upper is None:
+            return str(self.lower) + " " + Operand.LT.with_closed(self.lower_closed).value
+
+        if self.lower == self.upper and self.lower_closed and self.upper_closed:
+            return str(self.lower)
+
+        return str(self.lower) + " " + Operand.LT.with_closed(self.lower_closed).value + " x " \
+            + Operand.LT.with_closed(self.upper_closed).value + " " + str(self.upper)
+
+    def __repr__(self) -> str:
+        return f"Interval[{str(self)}]"
+
+    # For parsing Intervals automatically with pydantic.
+    @model_validator(mode="before")
+    @classmethod
+    def from_literal(cls, data: Any) -> Any:
+        if isinstance(data, (int, float, str)):
+            return vars(cls.from_string(str(data)))
+        return data
+
+    @model_serializer(mode='plain')
+    def serialize_model(self) -> str:
+        return str(self)
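A quick tour of the grammar `from_string` accepts, mirroring the four cases in the parser (the asserts are illustrative only):

```python
from spras.interval import Interval

# Case 2: a bare number means exact equality, i.e. the interval [5, 5].
assert Interval.from_string("5") == Interval.single(5)
# Case 1: operand first, with an optional identifier: "<300" bounds above.
assert Interval.from_string("<300").upper == 300
# Case 3: number first: "1500 <" reads as "1500 < x", a strict lower bound.
assert Interval.from_string("1500 <").lower == 1500
# Case 4: double inequality; Operand.flip/combine normalize the direction,
# so a reversed chain denotes the same interval.
assert Interval.from_string("200 >= x > 100") == Interval.from_string("100 < x <= 200")
```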
diff --git a/spras/statistics.py b/spras/statistics.py
new file mode 100644
index 000000000..222051d23
--- /dev/null
+++ b/spras/statistics.py
@@ -0,0 +1,91 @@
+"""
+Graph statistics, used to power summary.py.
+
+We allow for arbitrary computation of any specific statistic on some graph,
+computing more than requested when several statistics share one computation.
+See the top-level `statistics_computation` dictionary for usage.
+"""
+
+import itertools
+from statistics import median
+from typing import Callable
+
+import networkx as nx
+
+
+def compute_degree(graph: nx.DiGraph) -> tuple[int, float]:
+    """
+    Computes the (max, median) degree of a `graph`.
+    """
+    # number_of_nodes is a cheap call
+    if graph.number_of_nodes() == 0:
+        return (0, 0.0)
+    else:
+        degrees = [deg for _, deg in graph.degree()]
+        return max(degrees), median(degrees)
+
+def compute_on_cc(directed_graph: nx.DiGraph) -> tuple[int, float]:
+    """
+    Computes the (max diameter, average path length) over the connected
+    components of the undirected version of `directed_graph`.
+    """
+    graph: nx.Graph = directed_graph.to_undirected()
+    cc = list(nx.connected_components(graph))
+    # Save the max diameter
+    # Use diameter only for components with ≥2 nodes (singleton components have diameter 0)
+    diameters = [
+        nx.diameter(graph.subgraph(c).copy()) if len(c) > 1 else 0
+        for c in cc
+    ]
+    max_diameter = max(diameters, default=0)
+
+    # Save the average path lengths
+    # Compute average shortest path length only for components with ≥2 nodes (undefined for singletons, set to 0.0)
+    avg_path_lengths = [
+        nx.average_shortest_path_length(graph.subgraph(c).copy()) if len(c) > 1 else 0.0
+        for c in cc
+    ]
+
+    if len(avg_path_lengths) != 0:
+        avg_path_len = sum(avg_path_lengths) / len(avg_path_lengths)
+    else:
+        avg_path_len = 0.0
+
+    return max_diameter, avg_path_len
+
+# The type signature here is imprecise: ideally, an n-tuple of names would map to
+# a function with exactly n outputs, but that isn't expressible in Python's type system.
+statistics_computation: dict[tuple[str, ...], Callable[[nx.DiGraph], tuple[float | int, ...]]] = {
+    ('Number of nodes',): lambda graph: (graph.number_of_nodes(),),
+    ('Number of edges',): lambda graph: (graph.number_of_edges(),),
+    ('Number of connected components',): lambda graph: (nx.number_connected_components(graph.to_undirected()),),
+    ('Density',): lambda graph: (nx.density(graph),),
+
+    ('Max degree', 'Median degree'): compute_degree,
+    ('Max diameter', 'Average path length'): compute_on_cc,
+}
+
+# All of the keys inside statistics_computation, flattened.
+statistics_options: list[str] = list(itertools.chain.from_iterable(statistics_computation.keys()))
+
+def compute_statistics(graph: nx.DiGraph, statistics: list[str]) -> dict[str, float | int]:
+    """
+    Computes the named `statistics` for a graph. Valid names are the keys of the
+    top-level `statistics_computation` dictionary in this file.
+    """
+
+    # Early scan of the requested statistics: we want to err as soon as possible
+    for stat in statistics:
+        if stat not in statistics_options:
+            raise RuntimeError(f"Statistic {stat} is not a computable statistic! Available statistics: {statistics_options}")
+
+    # Now, compute each statistics group only when at least one of its members was requested
+    computed_statistics: dict[str, float | int] = dict()
+    for statistic_tuple, compute in statistics_computation.items():
+        if not set(statistic_tuple).isdisjoint(statistics):
+            computed_tuple = compute(graph)
+
+            # strict=True guards against a compute function with the wrong arity
+            for stat, value in zip(statistic_tuple, computed_tuple, strict=True):
+                computed_statistics[stat] = value
+
+    # (and return only the statistics we wanted, in the requested order)
+    return {key: computed_statistics[key] for key in statistics}
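The tuple-valued keys in `statistics_computation` encode shared work: requesting either statistic of a pair runs the underlying function once for both, and only the requested keys are returned. A small sketch:

```python
import networkx as nx

from spras.statistics import compute_statistics

G = nx.DiGraph([("A", "B"), ("B", "C")])

# 'Max diameter' and 'Average path length' share one pass over the connected
# components; asking for just one still calls compute_on_cc once, but the
# result dict contains only the requested key.
assert compute_statistics(G, ['Max diameter']) == {'Max diameter': 2}
```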
diff --git a/test/heuristics/__init__.py b/test/heuristics/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/heuristics/fixtures/empty.txt b/test/heuristics/fixtures/empty.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/heuristics/fixtures/nonempty.txt b/test/heuristics/fixtures/nonempty.txt
new file mode 100644
index 000000000..8e9f8ac96
--- /dev/null
+++ b/test/heuristics/fixtures/nonempty.txt
@@ -0,0 +1 @@
+A B 1 D
diff --git a/test/heuristics/fixtures/undirected.txt b/test/heuristics/fixtures/undirected.txt
new file mode 100644
index 000000000..627d30073
--- /dev/null
+++ b/test/heuristics/fixtures/undirected.txt
@@ -0,0 +1 @@
+A B 1 U
diff --git a/test/heuristics/test_heuristics.py b/test/heuristics/test_heuristics.py
new file mode 100644
index 000000000..8011f5377
--- /dev/null
+++ b/test/heuristics/test_heuristics.py
@@ -0,0 +1,27 @@
+from pathlib import Path
+
+import pytest
+
+from spras.config.heuristics import GraphHeuristics, GraphHeuristicsError
+
+FIXTURES_DIR = Path('test', 'heuristics', 'fixtures')
+
+class TestHeuristics:
+    def parse(self, heuristics: dict) -> GraphHeuristics:
+        return GraphHeuristics.model_validate(heuristics)
+
+    def test_nonempty(self):
+        self.parse({'number_of_nodes': '>0', 'number_of_edges': '1'}
+                   ).validate_graph_from_file(FIXTURES_DIR / 'nonempty.txt')
+
+    def test_empty(self):
+        self.parse({'number_of_nodes': '<1'}
+                   ).validate_graph_from_file(FIXTURES_DIR / 'empty.txt')
+
+        with pytest.raises(GraphHeuristicsError):
+            self.parse({'number_of_nodes': '0<'}
+                       ).validate_graph_from_file(FIXTURES_DIR / 'empty.txt')
+
+    def test_undirected(self):
+        self.parse({'number_of_nodes': '1 < x < 3', 'number_of_edges': 2}
+                   ).validate_graph_from_file(FIXTURES_DIR / 'undirected.txt')
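Why `test_undirected` expects two edges: `validate_graph_from_file` mirrors every `U` edge, so the single fixture line `A B 1 U` becomes a pair of directed edges. A standalone sketch of that expansion:

```python
import networkx as nx

# Mirrors the direction handling in GraphHeuristics.validate_graph_from_file.
G = nx.DiGraph()
G.add_edge('A', 'B', Rank='1', Direction='U')
for source, target, data in list(G.edges(data=True)):
    if data['Direction'] == 'U':
        G.add_edge(target, source, **data)

assert G.number_of_nodes() == 2
assert G.number_of_edges() == 2  # A->B and B->A
```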
Interval.from_string("10<").lower == 10.0