Skip to content

Commit 0e75860

Browse files
committed
Fixed the MOON aggregation dtype issue by accumulating non-floating-point deltas in a float buffer and then casting the result back to
the original dtype (including bool/int handling). This prevents the Long-vs-Float in-place add error during weighted aggregation.
1 parent 15923a7 commit 0e75860

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

examples/server_aggregation/moon/moon_algorithm.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,25 @@
1919
class Algorithm(fedavg.Algorithm):
2020
"""Algorithm providing MOON aggregation utilities."""
2121

22+
@staticmethod
def _cast_tensor_like(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Cast ``tensor`` to the dtype of ``reference`` (handles bool/int safely).

    Casting rules, checked in order:

    * Same dtype: ``tensor`` is returned unchanged (the same object, no copy).
    * ``reference`` is floating point: plain ``Tensor.to`` conversion.
    * ``reference`` is bool: a floating-point ``tensor`` is thresholded at
      ``0.5``; any other ``tensor`` maps non-zero elements to ``True``.
    * Otherwise (e.g. integer ``reference``): a floating-point ``tensor`` is
      rounded with ``torch.round`` before conversion instead of being
      truncated by the cast; a non-floating ``tensor`` is converted directly.

    Args:
        tensor: The tensor whose values should be converted.
        reference: The tensor whose dtype is the conversion target. Only its
            dtype is inspected; its values are never read.

    Returns:
        A tensor holding the values of ``tensor`` in ``reference.dtype``.
    """
    target_dtype = reference.dtype

    # Nothing to convert; hand back the input itself rather than a copy.
    if tensor.dtype == target_dtype:
        return tensor

    # Float targets can represent every source value well enough for a
    # direct conversion.
    if torch.is_floating_point(reference):
        return tensor.to(target_dtype)

    if target_dtype == torch.bool:
        if torch.is_floating_point(tensor):
            # Treat the float as a fractional "vote": at least one half
            # counts as True.
            return tensor >= 0.5
        return tensor.ne(0)

    # Remaining non-float, non-bool targets: round first so that a float
    # such as 2.6 becomes 3 rather than being truncated to 2 by the cast.
    if torch.is_floating_point(tensor):
        return torch.round(tensor).to(target_dtype)

    return tensor.to(target_dtype)
40+
2241
def moon_snapshot(self, weights: Mapping[str, torch.Tensor]) -> dict:
2342
"""Create a safe snapshot of the provided weights."""
2443
# Use a deepcopy to avoid in-place mutations on tensors; keep on CPU
@@ -35,16 +54,27 @@ def moon_aggregate_deltas(
3554

3655
total = sum(u.report.num_samples for u in updates) or 1
3756

38-
aggregated: OrderedDict[str, torch.Tensor] = OrderedDict(
39-
(name, torch.zeros_like(delta))
40-
for name, delta in deltas_received[0].items()
41-
)
57+
reference = deltas_received[0]
58+
aggregated: OrderedDict[str, torch.Tensor] = OrderedDict()
59+
for name, delta in reference.items():
60+
if torch.is_floating_point(delta):
61+
aggregated[name] = torch.zeros_like(delta)
62+
else:
63+
aggregated[name] = torch.zeros_like(
64+
delta, dtype=torch.get_default_dtype()
65+
)
4266

4367
for u, delta in zip(updates, deltas_received):
4468
w = (u.report.num_samples or 0) / total
4569
if w == 0.0:
4670
continue
4771
for name, value in delta.items():
48-
aggregated[name] += value * w
72+
target = aggregated[name]
73+
if not torch.is_floating_point(value) or value.dtype != target.dtype:
74+
value = value.to(target.dtype)
75+
aggregated[name] = target + value * w
76+
77+
for name, ref_tensor in reference.items():
78+
aggregated[name] = self._cast_tensor_like(aggregated[name], ref_tensor)
4979

5080
return aggregated

0 commit comments

Comments
 (0)