diff --git a/trumania/core/relationship.py b/trumania/core/relationship.py index 6e6b943..78f8107 100644 --- a/trumania/core/relationship.py +++ b/trumania/core/relationship.py @@ -11,15 +11,15 @@ # There are a lot of somewhat ugly optimizations here like in-place mutations, # caching, or usage of numpy instead of a more readable pandas alternative. The -# reason is the methods of this filetend to be called a large amount of time +# reason is that the methods of this file tend to be called a large number of times # in inner loop of the simulation, optimizing them make the whole simulation # faster. -class Relations(object): +class OutgoingRelations(object): """ - This entity contains all the "to" sides of the relationships of a given - "from", together with the related weights. + For a given "from", this entity contains all the "to" sides of the + relationship with the related weights. This data structure seems to be the most optimal since it corresponds to a cached group-by result, and those group-by are expensive in the select_one @@ -45,8 +45,10 @@ def from_tuples(from_ids, to_ids, weights): a relationship is built here for each "line" read across those 3 arrays. - This methods builds one instance of Relations for each unique from_id + This method builds one instance of OutgoingRelations for each unique from_id value, containing all the to_id's it is related to. + + :returns: Dictionary { id1 -> OutgoingRelations1, id2 -> OutgoingRelations2, ... 
} """ from_ids = np.array(from_ids) @@ -60,11 +62,12 @@ def from_tuples(from_ids, to_ids, weights): order = from_ids.argsort() ordered = zip(from_ids[order], to_ids[order], weights[order]) + # Find for every unique id in from_ids their matching "to" relations def _relations(): # itertools.groupby is much faster than pandas for from_id, tuples in itertools.groupby(ordered, lambda t: t[0]): - to_ids, weights = list(zip(*tuples))[1: 3] - yield from_id, Relations(list(to_ids), list(weights)) + to_ids, weights = list(zip(*tuples))[1:3] + yield from_id, OutgoingRelations(list(to_ids), list(weights)) return {from_id: relz for from_id, relz in _relations()} @@ -72,7 +75,7 @@ def plus(self, other): """ Merge function for 2 sets of relations all starting from the same "from" """ - return Relations( + return OutgoingRelations( np.hstack([self.to_ids, other.to_ids]), np.hstack([self.weights, other.weights])) @@ -83,7 +86,7 @@ def minus(self, other): """ removed_indices = np.argwhere( [idx in other.to_ids for idx in self.to_ids]) - return Relations( + return OutgoingRelations( np.delete(self.to_ids, removed_indices), np.delete(self.weights, removed_indices)) @@ -157,7 +160,7 @@ def add_relations(self, from_ids, to_ids, weights=1): self.grouped = utils.merge_2_dicts( self.grouped, - Relations.from_tuples(from_ids, to_ids, weights), + OutgoingRelations.from_tuples(from_ids, to_ids, weights), lambda r1, r2: r1.plus(r2)) def add_grouped_relations(self, from_ids, grouped_ids): @@ -185,7 +188,7 @@ def remove_relations(self, from_ids, to_ids): self.grouped = utils.merge_2_dicts( self.grouped, - Relations.from_tuples(from_ids, to_ids, weights=0), + OutgoingRelations.from_tuples(from_ids, to_ids, weights=0), lambda r1, r2: r1.minus(r2)) def get_relations(self, from_ids=None): diff --git a/trumania/core/util_functions.py b/trumania/core/util_functions.py index f5495d5..727c8ef 100644 --- a/trumania/core/util_functions.py +++ b/trumania/core/util_functions.py @@ -89,23 +89,33 @@ def 
merge_2_dicts(dict1, dict2, value_merge_func=None): if dict1 is None: return dict2 - def merged_value(key): - if key not in dict1: - return dict2[key] - elif key not in dict2: - return dict1[key] - else: - if value_merge_func is None: - raise ValueError( - "Conflict in merged dictionaries: merge function not " - "provided but key {} exists in both dictionaries".format( - key)) - - return value_merge_func(dict1[key], dict2[key]) - - keys = set(dict1.keys()) | set(dict2.keys()) - - return {key: merged_value(key) for key in keys} + if dict1 == dict2: + for k, v in dict1.items(): + dict1[k] = value_merge_func(v, v) + return dict1 + + dict1_set = set(dict1) + dict2_set = set(dict2) + + keys_to_merge = dict1_set.intersection(dict2_set) + + if len(keys_to_merge) != 0 and value_merge_func is None: + raise ValueError( + "Conflict in merged dictionaries: merge function not " + "provided but keys {} exists in both dictionaries".format( + keys_to_merge)) + + values_merged = dict() + + for key_to_merge in keys_to_merge: + old_value1 = dict1[key_to_merge] + old_value2 = dict2[key_to_merge] + + new_value = value_merge_func(old_value1, old_value2) + + values_merged[key_to_merge] = new_value + + return {**dict1, **dict2, **values_merged} def df_concat(d1, d2):