
Commit fdb256b

Updated the MOON example in Plato so it can run ResNet‑18 (CIFAR‑10) and VGG‑16 (CINIC‑10) with a MOON projection head.
1 parent 7d80e2e commit fdb256b

File tree

2 files changed (+141, -2 lines)


examples/server_aggregation/moon/moon.py

Lines changed: 2 additions & 2 deletions
@@ -6,12 +6,12 @@
 
 import moon_client
 import moon_server
-from moon_model import Model as MoonModel
+from moon_model_factory import resolve_moon_model
 
 
 def main():
     """Launch a Plato training session with the MOON algorithm."""
-    model = MoonModel
+    model = resolve_moon_model()
     client = moon_client.create_client(model=model)
     server = moon_server.Server(model=model)
     server.run(client)
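
Note that the factory returns a model class rather than an instance, matching how main() previously passed MoonModel itself; Plato's client and server construct the model later. A minimal sketch of the resolved call (illustrative, not part of the commit), assuming a session whose loaded config sets trainer.model_name to resnet_18:

# Illustrative sketch; assumes Config() has already loaded a Plato config
# with trainer.model_name set to "resnet_18".
from moon_model_factory import resolve_moon_model

model_cls = resolve_moon_model()   # returns a class, not an instance
model = model_cls()                # num_classes falls back to the config
print(model_cls.__name__)          # -> "MoonResNetWithProjection"
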
examples/server_aggregation/moon/moon_model_factory.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+"""
+Model factory for MOON experiments.
+
+Selects a MOON-compatible model (with projection head) based on the configured
+trainer.model_name. Supports LeNet-5 (EMNIST/FEMNIST), ResNet-18 (CIFAR-10),
+and VGG-16 (CINIC-10) using the same unified settings as other runs.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from plato.config import Config
+from plato.models import resnet, vgg
+
+from moon_model import Model as MoonLeNetModel
+
+
+def _resolve_model_name() -> str:
+    trainer = getattr(Config(), "trainer", None)
+    model_name = getattr(trainer, "model_name", None) if trainer else None
+    if not isinstance(model_name, str):
+        return "lenet5"
+    normalized = model_name.lower().replace("-", "_")
+    if normalized == "resnet18":
+        return "resnet_18"
+    if normalized == "vgg16":
+        return "vgg_16"
+    return normalized
+
+
+def _resolve_num_classes(default: int = 10) -> int:
+    parameters = getattr(Config(), "parameters", None)
+    model = getattr(parameters, "model", None) if parameters else None
+    num_classes = getattr(model, "num_classes", None) if model else None
+    return int(num_classes) if isinstance(num_classes, int) else default
+
+
+class MoonLeNetWithProjection(MoonLeNetModel):
+    """LeNet-5 MOON model with config-driven class count."""
+
+    def __init__(self, num_classes: int | None = None, projection_dim: int = 128, **_):
+        if num_classes is None:
+            num_classes = _resolve_num_classes(default=10)
+        super().__init__(num_classes=num_classes, projection_dim=projection_dim)
+
+
+class MoonResNetWithProjection(nn.Module):
+    """ResNet-18 backbone with a MOON projection head."""
+
+    def __init__(self, num_classes: int | None = None, projection_dim: int = 128, **_):
+        super().__init__()
+        if num_classes is None:
+            num_classes = _resolve_num_classes(default=10)
+        model_name = _resolve_model_name()
+        if not model_name.startswith("resnet_"):
+            model_name = "resnet_18"
+        self.base = resnet.Model.get(model_name=model_name, num_classes=num_classes)
+        self.projection_head = nn.Sequential(
+            nn.Linear(512 * resnet.BasicBlock.expansion, projection_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(projection_dim, projection_dim),
+        )
+
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        out = F.relu(self.base.bn1(self.base.conv1(x)))
+        out = self.base.layer1(out)
+        out = self.base.layer2(out)
+        out = self.base.layer3(out)
+        out = self.base.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        features = self._encode(x)
+        logits = self.base.linear(features)
+        return logits
+
+    def forward_with_projection(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        features = self._encode(x)
+        projection = self.projection_head(features)
+        projection = F.normalize(projection, dim=1, eps=1e-12)
+        logits = self.base.linear(features)
+        return features, projection, logits
+
+
+class MoonVGGWithProjection(nn.Module):
+    """VGG-16 backbone with a MOON projection head."""
+
+    def __init__(self, num_classes: int | None = None, projection_dim: int = 128, **_):
+        super().__init__()
+        if num_classes is None:
+            num_classes = _resolve_num_classes(default=10)
+        model_name = _resolve_model_name()
+        if not model_name.startswith("vgg_"):
+            model_name = "vgg_16"
+        self.base = vgg.Model.get(model_name=model_name, num_classes=num_classes)
+        self.projection_head = nn.Sequential(
+            nn.Linear(512, projection_dim),
+            nn.ReLU(inplace=True),
+            nn.Linear(projection_dim, projection_dim),
+        )
+
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.base.layers(x)
+        x = nn.AvgPool2d(2)(x)
+        x = x.view(x.size(0), -1)
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        features = self._encode(x)
+        logits = self.base.fc(features)
+        return logits
+
+    def forward_with_projection(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        features = self._encode(x)
+        projection = self.projection_head(features)
+        projection = F.normalize(projection, dim=1, eps=1e-12)
+        logits = self.base.fc(features)
+        return features, projection, logits
+
+
+def resolve_moon_model() -> Any:
+    """Return the MOON-compatible model class for the configured trainer model."""
+    model_name = _resolve_model_name()
+    if model_name.startswith("resnet_"):
+        return MoonResNetWithProjection
+    if model_name.startswith("vgg_"):
+        return MoonVGGWithProjection
+    return MoonLeNetWithProjection
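
For context on why forward_with_projection returns an L2-normalized projection alongside the features and logits: in MOON's model-contrastive loss, the projection from the current local model forms a positive pair with the global model's projection and a negative pair with the previous local model's. A minimal sketch of that loss (following the MOON paper, not code from this commit), assuming all three models expose the forward_with_projection interface above; the function name and temperature value are illustrative:

# Sketch of MOON's model-contrastive loss against the interface above.
import torch
import torch.nn.functional as F


def moon_contrastive_loss(local, global_model, previous_local, x, temperature=0.5):
    _, z, _ = local.forward_with_projection(x)
    with torch.no_grad():
        _, z_global, _ = global_model.forward_with_projection(x)
        _, z_previous, _ = previous_local.forward_with_projection(x)
    # Projections are already L2-normalized, so dot products are cosine similarities.
    sim_positive = (z * z_global).sum(dim=1) / temperature
    sim_negative = (z * z_previous).sum(dim=1) / temperature
    # Cross-entropy with the positive (global) pair treated as class 0.
    pair_logits = torch.stack([sim_positive, sim_negative], dim=1)
    labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
    return F.cross_entropy(pair_logits, labels)
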
