Skip to content

Commit

Permalink
add induced feature
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Mar 11, 2024
1 parent 6f28f68 commit 7801fa5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 27 deletions.
34 changes: 26 additions & 8 deletions networkframe/groupby.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
class NodeGroupBy:
"""A class for grouping a `NetworkFrame` by a set of labels."""

def __init__(self, frame, source_groupby, target_groupby):
def __init__(self, frame, source_groupby, target_groupby, induced: bool = False):
"""Groupby on nodes.
Parameters
----------
frame : _type_
frame
_description_
source_groupby : _type_
source_groupby
_description_
target_groupby : _type_
target_groupby
_description_
induced
_description_
"""
self._frame = frame
self._source_groupby = source_groupby
self._target_groupby = target_groupby
self.induced = induced

if source_groupby is None:
self._axis = 1
Expand All @@ -29,6 +32,18 @@ def __init__(self, frame, source_groupby, target_groupby):
if self.has_target_groups:
self.target_group_names = list(target_groupby.groups.keys())

def __len__(self):
"""Return the number of groups."""
if self._axis == "both":
if self.induced:
return len(self._source_groupby)
else:
return len(self._source_groupby) * len(self._target_groupby)
elif self._axis == 0:
return len(self._source_groupby)
elif self._axis == 1:
return len(self._target_groupby)

@property
def has_source_groups(self):
"""Whether the frame has row groups."""
Expand All @@ -44,10 +59,13 @@ def __iter__(self):
if self._axis == "both":
for source_group, source_objects in self._source_groupby:
for target_group, target_objects in self._target_groupby:
yield (
(source_group, target_group),
self._frame.loc[source_objects.index, target_objects.index],
)
if self.induced and source_group != target_group:
continue
else:
yield (
(source_group, target_group),
self._frame.loc[source_objects.index, target_objects.index],
)
elif self._axis == 0:
for source_group, source_objects in self._source_groupby:
yield source_group, self._frame.loc[source_objects.index]
Expand Down
42 changes: 23 additions & 19 deletions networkframe/networkframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def select_component_from_node(
)

def groupby_nodes(
self, by: Union[Any, list], axis: EdgeAxisType = "both", **kwargs
self, by: Union[Any, list], axis: EdgeAxisType = "both", induced=False, **kwargs
) -> "NodeGroupBy":
"""Group the frame by node data for the source or target (or both).
Expand All @@ -837,6 +837,9 @@ def groupby_nodes(
axis
Whether to group by the source nodes (`source` or `0`), target nodes
(`target` or `0`), or both (`both`).
induced
Whether to only yield groups over induced subgraphs, as opposed to all
subgraphs.
kwargs
Additional keyword arguments to pass to [pandas.DataFrame.groupby][].
Expand Down Expand Up @@ -891,7 +894,9 @@ def groupby_nodes(
else:
raise ValueError("Axis must be 0 or 1 or 'both'")

return NodeGroupBy(self, source_nodes_groupby, target_nodes_groupby)
return NodeGroupBy(
self, source_nodes_groupby, target_nodes_groupby, induced=induced
)

@property
def loc(self) -> "LocIndexer":
Expand Down Expand Up @@ -949,20 +954,17 @@ def __eq__(self, other: object) -> bool:
edges1 = self.edges
edges2 = other.edges
if not nodes1.sort_index().equals(nodes2.sort_index()):
print("diff nodes")
return False

index1 = edges1.index.sort_values()
index2 = edges2.index.sort_values()
if not index1.equals(index2):
print("diff index")
return False

# sort the edges the same way (note the index1 twice)
edges1 = edges1.loc[index1]
edges2 = edges2.loc[index1]
if not edges1.equals(edges2):
print("diff edges")
return False

return True
Expand Down Expand Up @@ -1102,20 +1104,22 @@ def __getitem__(self, args):

if row_index.equals(col_index):
nodes = source_nodes
return NetworkFrame(
nodes,
edges,
directed=self._frame.directed,
validate=False,
)
return self._frame._return(nodes=nodes, edges=edges, inplace=False)
# return NetworkFrame(
# nodes,
# edges,
# directed=self._frame.directed,
# validate=False,
# )
else:
nodes = pd.concat([source_nodes, target_nodes], copy=False, sort=False)
nodes = nodes.loc[~nodes.index.duplicated(keep="first")]
return NetworkFrame(
nodes,
edges,
directed=self._frame.directed,
sources=row_index,
targets=col_index,
validate=False,
)
return self._frame._return(nodes=nodes, edges=edges, inplace=False)
# return NetworkFrame(
# nodes,
# edges,
# directed=self._frame.directed,
# sources=row_index,
# targets=col_index,
# validate=False,
# )

0 comments on commit 7801fa5

Please sign in to comment.