1919class Algorithm (fedavg .Algorithm ):
2020 """Algorithm providing MOON aggregation utilities."""
2121
22+ @staticmethod
23+ def _cast_tensor_like (tensor : torch .Tensor , reference : torch .Tensor ) -> torch .Tensor :
24+ """Cast a tensor to match a reference dtype (handles bool/int safely)."""
25+ if tensor .dtype == reference .dtype :
26+ return tensor
27+
28+ if torch .is_floating_point (reference ):
29+ return tensor .to (reference .dtype )
30+
31+ if reference .dtype == torch .bool :
32+ if torch .is_floating_point (tensor ):
33+ return tensor >= 0.5
34+ return tensor .ne (0 )
35+
36+ if torch .is_floating_point (tensor ):
37+ return torch .round (tensor ).to (reference .dtype )
38+
39+ return tensor .to (reference .dtype )
40+
2241 def moon_snapshot (self , weights : Mapping [str , torch .Tensor ]) -> dict :
2342 """Create a safe snapshot of the provided weights."""
2443 # Use a deepcopy to avoid in-place mutations on tensors; keep on CPU
@@ -35,16 +54,27 @@ def moon_aggregate_deltas(
3554
3655 total = sum (u .report .num_samples for u in updates ) or 1
3756
38- aggregated : OrderedDict [str , torch .Tensor ] = OrderedDict (
39- (name , torch .zeros_like (delta ))
40- for name , delta in deltas_received [0 ].items ()
41- )
57+ reference = deltas_received [0 ]
58+ aggregated : OrderedDict [str , torch .Tensor ] = OrderedDict ()
59+ for name , delta in reference .items ():
60+ if torch .is_floating_point (delta ):
61+ aggregated [name ] = torch .zeros_like (delta )
62+ else :
63+ aggregated [name ] = torch .zeros_like (
64+ delta , dtype = torch .get_default_dtype ()
65+ )
4266
4367 for u , delta in zip (updates , deltas_received ):
4468 w = (u .report .num_samples or 0 ) / total
4569 if w == 0.0 :
4670 continue
4771 for name , value in delta .items ():
48- aggregated [name ] += value * w
72+ target = aggregated [name ]
73+ if not torch .is_floating_point (value ) or value .dtype != target .dtype :
74+ value = value .to (target .dtype )
75+ aggregated [name ] = target + value * w
76+
77+ for name , ref_tensor in reference .items ():
78+ aggregated [name ] = self ._cast_tensor_like (aggregated [name ], ref_tensor )
4979
5080 return aggregated
0 commit comments