@@ -54,14 +54,15 @@ def fedadp_aggregate_deltas(
5454 total_samples = total_samples if total_samples > 0 else 1
5555
5656 # Global gradients as the sample-weighted average of client deltas
57- global_grads : OrderedDict [str , torch .Tensor ] = OrderedDict (
58- (name , torch .zeros_like (tensor ))
59- for name , tensor in deltas_received [0 ].items ()
60- )
57+ global_grads : OrderedDict [str , torch .Tensor ] = OrderedDict ()
58+ reference_delta = deltas_received [0 ]
59+ for name , tensor in reference_delta .items ():
60+ base = self ._to_float_tensor (tensor )
61+ global_grads [name ] = torch .zeros_like (base , dtype = base .dtype )
6162 for idx , delta in enumerate (deltas_received ):
6263 weight = num_samples [idx ] / total_samples
6364 for name , value in delta .items ():
64- global_grads [name ] += value * weight
65+ global_grads [name ] += self . _to_float_tensor ( value ) * weight
6566
6667 # Compute adaptive weighting
6768 contribs = self ._calc_contribution (
@@ -81,16 +82,19 @@ def fedadp_aggregate_deltas(
8182 weights [i ] = (num_samples [i ] * math .exp (c )) / denom
8283
8384 # Aggregate deltas with the computed weights
84- agg : OrderedDict [str , torch .Tensor ] = OrderedDict (
85- ( name , torch . zeros_like ( tensor ))
86- for name , tensor in deltas_received [ 0 ]. items ( )
87- )
85+ agg : OrderedDict [str , torch .Tensor ] = OrderedDict ()
86+ for name , tensor in reference_delta . items ():
87+ base = self . _to_float_tensor ( tensor )
88+ agg [ name ] = torch . zeros_like ( base , dtype = base . dtype )
8889 for idx , delta in enumerate (deltas_received ):
8990 w = weights [idx ]
9091 if w == 0.0 :
9192 continue
9293 for name , value in delta .items ():
93- agg [name ] += value * w
94+ agg [name ] += self ._to_float_tensor (value ) * w
95+
96+ for name , reference in reference_delta .items ():
97+ agg [name ] = self ._cast_tensor_like (agg [name ], reference )
9498
9599 return agg
96100
@@ -160,3 +164,31 @@ def to_np(t: torch.Tensor) -> np.ndarray:
160164 arr = to_np (tensor )
161165 flat = np .append (flat , - arr / lr )
162166 return flat
167+
168+ @staticmethod
169+ def _to_float_tensor (tensor : torch .Tensor ) -> torch .Tensor :
170+ """Ensure a tensor is floating for weighted accumulation."""
171+ if torch .is_floating_point (tensor ):
172+ return tensor
173+ return tensor .to (torch .get_default_dtype ())
174+
175+ @staticmethod
176+ def _cast_tensor_like (
177+ tensor : torch .Tensor , reference : torch .Tensor
178+ ) -> torch .Tensor :
179+ """Cast a tensor to match the dtype of a reference tensor."""
180+ if tensor .dtype == reference .dtype :
181+ return tensor
182+
183+ if torch .is_floating_point (reference ):
184+ return tensor .to (reference .dtype )
185+
186+ if reference .dtype == torch .bool :
187+ if torch .is_floating_point (tensor ):
188+ return tensor >= 0.5
189+ return tensor .ne (0 )
190+
191+ if torch .is_floating_point (tensor ):
192+ return torch .round (tensor ).to (reference .dtype )
193+
194+ return tensor .to (reference .dtype )
0 commit comments