Skip to content

Commit 00045ba

Browse files
committed
Adjusted FedAdp aggregation to accumulate deltas in floating point and then cast back to the original tensor dtypes, so integer/bool buffers (e.g., BatchNorm counters) no longer trigger Float→Long errors during weighted sums.
1 parent 56dad23 commit 00045ba

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

examples/server_aggregation/fedadp/fedadp_algorithm.py

Lines changed: 42 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -54,14 +54,15 @@ def fedadp_aggregate_deltas(
5454
total_samples = total_samples if total_samples > 0 else 1
5555

5656
# Global gradients as the sample-weighted average of client deltas
57-
global_grads: OrderedDict[str, torch.Tensor] = OrderedDict(
58-
(name, torch.zeros_like(tensor))
59-
for name, tensor in deltas_received[0].items()
60-
)
57+
global_grads: OrderedDict[str, torch.Tensor] = OrderedDict()
58+
reference_delta = deltas_received[0]
59+
for name, tensor in reference_delta.items():
60+
base = self._to_float_tensor(tensor)
61+
global_grads[name] = torch.zeros_like(base, dtype=base.dtype)
6162
for idx, delta in enumerate(deltas_received):
6263
weight = num_samples[idx] / total_samples
6364
for name, value in delta.items():
64-
global_grads[name] += value * weight
65+
global_grads[name] += self._to_float_tensor(value) * weight
6566

6667
# Compute adaptive weighting
6768
contribs = self._calc_contribution(
@@ -81,16 +82,19 @@ def fedadp_aggregate_deltas(
8182
weights[i] = (num_samples[i] * math.exp(c)) / denom
8283

8384
# Aggregate deltas with the computed weights
84-
agg: OrderedDict[str, torch.Tensor] = OrderedDict(
85-
(name, torch.zeros_like(tensor))
86-
for name, tensor in deltas_received[0].items()
87-
)
85+
agg: OrderedDict[str, torch.Tensor] = OrderedDict()
86+
for name, tensor in reference_delta.items():
87+
base = self._to_float_tensor(tensor)
88+
agg[name] = torch.zeros_like(base, dtype=base.dtype)
8889
for idx, delta in enumerate(deltas_received):
8990
w = weights[idx]
9091
if w == 0.0:
9192
continue
9293
for name, value in delta.items():
93-
agg[name] += value * w
94+
agg[name] += self._to_float_tensor(value) * w
95+
96+
for name, reference in reference_delta.items():
97+
agg[name] = self._cast_tensor_like(agg[name], reference)
9498

9599
return agg
96100

@@ -160,3 +164,31 @@ def to_np(t: torch.Tensor) -> np.ndarray:
160164
arr = to_np(tensor)
161165
flat = np.append(flat, -arr / lr)
162166
return flat
167+
168+
@staticmethod
169+
def _to_float_tensor(tensor: torch.Tensor) -> torch.Tensor:
170+
"""Ensure a tensor is floating for weighted accumulation."""
171+
if torch.is_floating_point(tensor):
172+
return tensor
173+
return tensor.to(torch.get_default_dtype())
174+
175+
@staticmethod
176+
def _cast_tensor_like(
177+
tensor: torch.Tensor, reference: torch.Tensor
178+
) -> torch.Tensor:
179+
"""Cast a tensor to match the dtype of a reference tensor."""
180+
if tensor.dtype == reference.dtype:
181+
return tensor
182+
183+
if torch.is_floating_point(reference):
184+
return tensor.to(reference.dtype)
185+
186+
if reference.dtype == torch.bool:
187+
if torch.is_floating_point(tensor):
188+
return tensor >= 0.5
189+
return tensor.ne(0)
190+
191+
if torch.is_floating_point(tensor):
192+
return torch.round(tensor).to(reference.dtype)
193+
194+
return tensor.to(reference.dtype)

0 commit comments

Comments
 (0)