Skip to content

Commit 8775b0b

Browse files
authored
Merge pull request #1234 from barfeous/bugs/augmentedGraphView
fix: Avoid modifying augmented graph view collection upon traversal
2 parents 27a9e91 + 4a31621 commit 8775b0b

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs

+5-9
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,19 @@ private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concre
8888

8989
public override (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
9090
{
91-
Trackable get_merged_trackable(Trackable x)
91+
void merged_trackable(Trackable x)
9292
{
9393
// TODO: complete it with new definitions `Asset` and `TrackableConstant`.
94-
return x;
9594
}
95+
9696
var trackable_objects = base.breadth_first_traversal();
9797

9898
foreach(var obj in _children_cache.Keys)
9999
{
100100
// skip the deletion of cache (maybe do it later).
101101
foreach(var pair in _children_cache[obj])
102102
{
103-
_children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
103+
merged_trackable(pair.Value);
104104
}
105105
}
106106

@@ -109,15 +109,11 @@ Trackable get_merged_trackable(Trackable x)
109109

110110
public List<(string, Trackable)> list_dependencies(Trackable obj)
111111
{
112-
IDictionary<string, Trackable> children;
113-
if (!_children_cache.ContainsKey(obj))
112+
if (!_children_cache.TryGetValue(obj, out var children))
114113
{
115114
children= new Dictionary<string, Trackable>();
116115
}
117-
else
118-
{
119-
children= _children_cache[obj];
120-
}
116+
121117
List<(string, Trackable)> res = new();
122118
foreach(var pair in obj.deserialization_dependencies(children))
123119
{

0 commit comments

Comments
 (0)