Skip to content

Commit 6d11774

Browse files
authored
feat: add checkpoint management to TrainingClient (#29)
* feat: add checkpoint management to TrainingClient (#28) Add save_state, load_state, load_state_with_optimizer, and list_checkpoints methods to TrainingClient, enabling checkpoint-based resume training through the SDK. Fixes #28 * fix: align checkpoint API with actual server endpoints - save_state: synchronous POST (not operation), sends {type, path} - load_state: sends checkpoint_id (not path), returns OperationHandle - Checkpoint.from_payload: handle server's "type" field - Add integration test (save/list verified; load blocked by server) * fix(client): send name instead of path in save_state request The server should generate the checkpoint path (including model_id), not the SDK. Changed save_state() to send {"name": ...} instead of {"path": ...} so the server can construct proper namespaced paths like weaver://{model_id}/checkpoints/{name}. Also added checkpoint test scripts for LoRA, FullFT, and baseline. Refs: china-qijizhifeng/weaver-server#106 * refactor(client): load_state accepts path instead of checkpoint_id save_state input is name, output is Checkpoint with server-generated path. load_state/load_state_with_optimizer now accept a path string (weaver:// URI) or Checkpoint object (extracts .path), and send {"path": ...} to the server instead of {"checkpoint_id": ...}. Refs: china-qijizhifeng/weaver-server#106 * chore: move checkpoint test scripts to tests/integration/ * refactor(types): rename checkpoint types to weight/weight_and_optimizer Replace "training" with "weight" and "training_with_optimizer" with "weight_and_optimizer" for clearer checkpoint type semantics. * feat(client): make save_state async with wait parameter save_state now dispatches an async operation via enqueue_operation (same pattern as load_state) instead of a synchronous POST. The server will dispatch a save task to the trainer, which writes weight files to disk. Adds wait parameter with overloads: returns Checkpoint when wait=True (default), OperationHandle when wait=False. 
Refs: china-qijizhifeng/weaver-server#109 * fix(client): fix save_state to use handle.result() for checkpoint parsing * fix(client): handle nested checkpoint+operation response in save_state Server returns {"checkpoint": {...}, "operation": {...}} instead of a flat operation response. Extract operation for polling and checkpoint data for the result. * fix(client): revert save_state to use enqueue_operation for flat response Server should return a flat Operation (like all other async endpoints) instead of nested {"checkpoint": ..., "operation": ...}. Revert to the standard enqueue_operation pattern. Checkpoint data will come from the operation's response field after completion. Refs: china-qijizhifeng/weaver-server#111
1 parent b37a59b commit 6d11774

10 files changed

Lines changed: 1356 additions & 0 deletions

tests/integration/__init__.py

Whitespace-only changes.

tests/integration/test_baseline.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) Nex-AGI. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Baseline: run LoRA (6 steps) and FullFT (6 steps) to record reference losses."""
16+
17+
from __future__ import annotations
18+
19+
import os
20+
import sys
21+
from typing import Any, Dict, List, Sequence
22+
23+
import torch
24+
25+
from weaver import ServiceClient, types
26+
27+
# Fixed English -> Pig Latin pairs; every baseline step trains on all of them.
EXAMPLES: List[Dict[str, str]] = [
    {"input": english, "output": pig_latin}
    for english, pig_latin in (
        ("banana split", "anana-bay plit-say"),
        ("quantum physics", "uantum-qay ysics-phay"),
        ("donut shop", "onut-day op-shay"),
        ("pickle jar", "ickle-pay ar-jay"),
        ("space exploration", "ace-spay exploration-way"),
        ("rubber duck", "ubber-ray uck-day"),
        ("coding wizard", "oding-cay izard-way"),
    )
]
36+
37+
38+
def process_example(example: Dict[str, str], tokenizer) -> types.Datum:
    """Build a next-token-prediction datum from one Pig Latin pair.

    The prompt is encoded with special tokens and the completion without.
    The concatenated sequence is then shifted by one position so that each
    input token is paired with the token that follows it.

    Args:
        example: Mapping with ``"input"`` (English) and ``"output"`` (Pig Latin).
        tokenizer: Object whose ``encode(text, add_special_tokens=...)`` returns
            a list of token ids (the two results are concatenated with ``+``).

    Returns:
        A ``types.Datum`` whose loss inputs hold the shifted target ids and a
        weight of 0.0 for prompt positions and 1.0 for completion positions.
    """
    prompt_ids = tokenizer.encode(
        f"English: {example['input']}\nPig Latin:", add_special_tokens=True
    )
    completion_ids = tokenizer.encode(
        f" {example['output']}\n\n", add_special_tokens=False
    )
    sequence = prompt_ids + completion_ids

    # The mask is built per token and then aligned with the *targets*, which
    # start at the second token, hence the leading entry is dropped.
    loss_mask = ([0.0] * len(prompt_ids) + [1.0] * len(completion_ids))[1:]

    return types.Datum(
        model_input=types.ModelInput.from_ints(sequence[:-1]),
        loss_fn_inputs={
            "target_tokens": torch.tensor(sequence[1:], dtype=torch.int64),
            "weights": torch.tensor(loss_mask, dtype=torch.float32),
        },
    )
54+
55+
56+
def _extract_logprobs(output: Dict[str, Any]) -> torch.Tensor:
57+
value = output.get("logprobs") or output.get("Logprobs")
58+
if isinstance(value, dict):
59+
value = value.get("data")
60+
if value is None:
61+
raise ValueError("Missing logprobs in forward/backward output")
62+
return torch.as_tensor(value, dtype=torch.float32)
63+
64+
65+
def compute_loss(
    fwdbwd_result: Dict[str, Any],
    processed_examples: Sequence[types.Datum],
) -> float:
    """Compute the weighted per-token loss from a forward/backward result.

    The loss is ``-(logprobs . weights) / sum(weights)``, where the log-probs of
    all ``result.loss_fn_outputs`` entries and the ``"weights"`` of all examples
    are each concatenated in order.

    Args:
        fwdbwd_result: Mapping expected to carry ``result.loss_fn_outputs``.
        processed_examples: The datums that were sent; only
            ``loss_fn_inputs["weights"]`` is read from each.

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: If the result has no loss-function outputs, or the weights
            sum to zero (the quotient would otherwise be NaN/inf).
    """
    # ``or {}`` also covers an explicit ``"result": None``.
    outputs = (fwdbwd_result.get("result") or {}).get("loss_fn_outputs") or []
    if not outputs:
        raise ValueError("No loss_fn_outputs in forward/backward result")

    logprobs = torch.cat([_extract_logprobs(o) for o in outputs], dim=0)
    weights = torch.cat([ex.loss_fn_inputs["weights"] for ex in processed_examples], dim=0)

    total_weight = weights.sum()
    if total_weight == 0:
        raise ValueError("Loss weights sum to zero; per-token loss is undefined")
    return float(-torch.dot(logprobs, weights) / total_weight)
73+
74+
75+
def run_baseline(training_mode: str | None, lr: float, steps: int = 6) -> List[float]:
    """Train a fresh Qwen/Qwen3-8B model on EXAMPLES and record the loss per step.

    Args:
        training_mode: Forwarded to ``create_model`` as ``training_mode``; when
            ``None`` the argument is omitted entirely.
        lr: Learning rate for the Adam parameters.
        steps: Number of forward-backward + optimizer steps to run.

    Returns:
        The per-token loss of each step, in order. Each value is computed from
        that step's forward/backward result, and is printed as it is produced.
    """
    model_kwargs: Dict[str, Any] = {"base_model": "Qwen/Qwen3-8B"}
    if training_mode is not None:
        model_kwargs["training_mode"] = training_mode

    history: List[float] = []
    with ServiceClient(api_key=os.getenv("WEAVER_API_KEY")) as client:
        trainer = client.create_model(**model_kwargs)
        tokenizer = trainer.get_tokenizer()
        batch = [process_example(example, tokenizer) for example in EXAMPLES]
        optimizer_params = types.AdamParams(learning_rate=lr)

        for step_index in range(steps):
            fwdbwd = trainer.forward_backward(batch, "cross_entropy", wait=True)
            trainer.optim_step(optimizer_params, wait=True)
            step_loss = compute_loss(fwdbwd, batch)
            history.append(step_loss)
            print(f" Step {step_index}: loss/token={step_loss:.6f}")
    return history
93+
94+
95+
def main() -> None:
    """Run the two baselines back to back and print both loss curves.

    The first run passes no training mode (labelled LoRA, lr=1e-4); the second
    passes ``"full_ft"`` (labelled FullFT, lr=1e-5). Both run 6 steps.
    """
    rule = "=" * 60

    def banner(title: str) -> None:
        # Section heading framed by two 60-character rules.
        print(rule)
        print(title)
        print(rule)

    banner("BASELINE: LoRA 6 steps (lr=1e-4)")
    lora_losses = run_baseline(None, lr=1e-4, steps=6)

    print()
    banner("BASELINE: FullFT 6 steps (lr=1e-5)")
    fullft_losses = run_baseline("full_ft", lr=1e-5, steps=6)

    print()
    banner("SUMMARY")
    print("LoRA losses:", [f"{value:.6f}" for value in lora_losses])
    print("FullFT losses:", [f"{value:.6f}" for value in fullft_losses])
113+
114+
115+
if __name__ == "__main__":
    # Only contact the Weaver service when run as a script, never on import.
    main()
tests/integration/test_checkpoint_e2e.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Copyright (c) Nex-AGI. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
End-to-end integration test for checkpoint management.
17+
18+
Verifies save_state / load_state round-trip against a live Weaver server
19+
by comparing forward-backward loss values:
20+
21+
1. Train 3 steps → save checkpoint → compute loss (A)
22+
2. Train 3 more steps → compute loss (B, should differ from A)
23+
3. load_state back to checkpoint → compute loss (C)
24+
4. Assert C == A (checkpoint restored correctly)
25+
26+
Usage:
27+
WEAVER_API_KEY=sk-... python tests/integration/test_checkpoint_e2e.py
28+
"""
29+
30+
from __future__ import annotations
31+
32+
import logging
33+
import sys
34+
from typing import Any, Dict, List
35+
36+
import torch
37+
38+
from weaver import ServiceClient, types
39+
from weaver.types.checkpoint import Checkpoint
40+
41+
# Timestamped INFO-level records make each phase of the test visible in the output.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Server endpoint and base model the test runs against.
BASE_URL = "https://weaver-console.nex-agi.cn"
BASE_MODEL = "Qwen/Qwen3-8B"
# Steps per training phase: once before the checkpoint, once more to drift away from it.
NUM_TRAINING_STEPS = 3
47+
48+
# English -> Pig Latin training pairs, in the same style as examples/pig_latin.py.
EXAMPLES: List[Dict[str, str]] = [
    {"input": source_text, "output": target_text}
    for source_text, target_text in (
        ("banana split", "anana-bay plit-say"),
        ("quantum physics", "uantum-qay ysics-phay"),
        ("donut shop", "onut-day op-shay"),
        ("pickle jar", "ickle-pay ar-jay"),
        ("space exploration", "ace-spay exploration-way"),
        ("rubber duck", "ubber-ray uck-day"),
        ("coding wizard", "oding-cay izard-way"),
    )
]
58+
59+
60+
def process_example(example: Dict[str, str], tokenizer) -> types.Datum:
    """Build a cross-entropy datum from a Pig Latin example.

    Args:
        example: Mapping with ``"input"`` (English) and ``"output"`` (Pig Latin).
        tokenizer: Object whose ``encode(text, add_special_tokens=...)`` returns
            a list of token ids (the two results are concatenated with ``+``).

    Returns:
        A ``types.Datum`` whose model input is every token but the last, and
        whose loss inputs are the tokens shifted left by one plus a matching
        weight vector (0.0 on prompt positions, 1.0 on completion positions).
    """
    english, pig_latin = example["input"], example["output"]
    prompt_ids = tokenizer.encode(f"English: {english}\nPig Latin:", add_special_tokens=True)
    completion_ids = tokenizer.encode(f" {pig_latin}\n\n", add_special_tokens=False)

    all_ids = prompt_ids + completion_ids
    per_token_weight = [0.0] * len(prompt_ids) + [1.0] * len(completion_ids)

    # Targets are the sequence shifted by one; the weights follow the targets,
    # so the weight of the very first token is dropped.
    targets = torch.tensor(all_ids[1:], dtype=torch.int64)
    target_weights = torch.tensor(per_token_weight[1:], dtype=torch.float32)

    return types.Datum(
        model_input=types.ModelInput.from_ints(all_ids[:-1]),
        loss_fn_inputs={"target_tokens": targets, "weights": target_weights},
    )
80+
81+
82+
def _extract_logprobs(output: Dict[str, Any]) -> torch.Tensor:
83+
"""Extract logprobs from a forward-backward output."""
84+
value = output.get("logprobs") or output.get("Logprobs")
85+
if isinstance(value, dict):
86+
value = value.get("data")
87+
if value is None:
88+
raise ValueError("Missing logprobs in forward/backward output")
89+
return torch.as_tensor(value, dtype=torch.float32)
90+
91+
92+
def compute_loss(fwdbwd_result: Dict[str, Any], data: list[types.Datum]) -> float:
    """Compute weighted cross-entropy loss from forward-backward result.

    The loss is ``-(logprobs . weights) / sum(weights)``, where the log-probs of
    all ``result.loss_fn_outputs`` entries and the ``"weights"`` of all datums
    are each concatenated in order.

    Args:
        fwdbwd_result: Mapping expected to carry ``result.loss_fn_outputs``.
        data: The datums that were sent; only ``loss_fn_inputs["weights"]`` is
            read from each.

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: If the result has no loss-function outputs, or the weights
            sum to zero. A NaN loss would otherwise flow into the A/B/C
            comparisons and produce a misleading verdict.
    """
    # ``or {}`` also covers an explicit ``"result": None``.
    outputs = (fwdbwd_result.get("result") or {}).get("loss_fn_outputs") or []
    if not outputs:
        raise ValueError("No loss_fn_outputs in forward/backward result")

    logprobs = torch.cat([_extract_logprobs(output) for output in outputs], dim=0)
    weights = torch.cat([d.loss_fn_inputs["weights"] for d in data], dim=0)

    total_weight = weights.sum()
    if total_weight == 0:
        raise ValueError("Loss weights sum to zero; per-token loss is undefined")
    return float(-torch.dot(logprobs, weights) / total_weight)
99+
100+
101+
def train_steps(tc, data, adam, n: int) -> float:
    """Run ``n`` forward-backward + optimizer steps and return the last loss.

    Each step's loss is computed from that step's forward/backward result and
    logged. Returns 0.0 when ``n`` is 0, because no step ran.
    """
    latest = 0.0
    for step_number in range(1, n + 1):
        fwdbwd = tc.forward_backward(data, "cross_entropy")
        tc.optim_step(adam)
        latest = compute_loss(fwdbwd, data)
        log.info(" Step %d: loss=%.6f", step_number, latest)
    return latest
110+
111+
112+
def eval_loss(tc, data) -> float:
    """Measure the current loss with a forward-backward pass.

    The pass is followed by an optimizer step whose learning rate is 0.0, so
    the measurement is not meant to move the weights.
    """
    fwdbwd = tc.forward_backward(data, "cross_entropy")
    current = compute_loss(fwdbwd, data)
    # Per the original author's note, gradients produced by forward_backward
    # are consumed by optim_step, so a zero-learning-rate step is issued to
    # clear them.
    # NOTE(review): whether such a step also leaves the optimizer's internal
    # state untouched is not visible from here - confirm.
    tc.optim_step(types.AdamParams(learning_rate=0.0))
    return current
122+
123+
124+
def main() -> int:
    """Run the save/drift/restore round-trip against the server and judge it.

    Sequence: create a model, train, save a checkpoint, measure loss (A),
    train further, measure loss (B), list checkpoints, load the checkpoint,
    measure loss (C).

    Returns:
        0 if C matches A within 1e-4, or, failing that, if C at least differs
        from B by more than 1e-4. Returns 1 if C is indistinguishable from B.

    Raises:
        AssertionError: If the saved checkpoint has no id/path, if the loss did
            not change after the extra training, or if the saved checkpoint is
            absent from the listing.
    """
    log.info("=== Checkpoint E2E Integration Test ===")
    log.info("Connecting to %s with model %s", BASE_URL, BASE_MODEL)

    with ServiceClient(base_url=BASE_URL) as service:
        # 1. Create model
        log.info("Step 1: Creating training model...")
        trainer = service.create_model(base_model=BASE_MODEL)
        log.info("Model created: %s", trainer.model_id)

        # NOTE(review): the tokenizer is read as an attribute here - confirm the
        # training client exposes it as a property and not only via a getter.
        tokenizer = trainer.tokenizer
        batch = [process_example(example, tokenizer) for example in EXAMPLES]
        optimizer_params = types.AdamParams(learning_rate=1e-4)

        # 2. Initial training phase
        log.info("Step 2: Training %d steps...", NUM_TRAINING_STEPS)
        train_steps(trainer, batch, optimizer_params, NUM_TRAINING_STEPS)

        # 3. Save checkpoint
        log.info("Step 3: Saving checkpoint...")
        saved = trainer.save_state(name="after-3-steps")
        log.info(" Checkpoint saved: id=%s path=%s", saved.id, saved.path)
        assert saved.id, "save_state() returned empty checkpoint id"
        assert saved.path, "save_state() returned empty checkpoint path"

        # 4. Loss at the checkpoint (A)
        log.info("Step 4: Evaluating loss at checkpoint...")
        loss_a = eval_loss(trainer, batch)
        log.info(" Loss (A) at checkpoint: %.6f", loss_a)

        # 5. Keep training so the weights move away from the checkpoint
        log.info("Step 5: Training %d more steps to drift weights...", NUM_TRAINING_STEPS)
        train_steps(trainer, batch, optimizer_params, NUM_TRAINING_STEPS)

        # 6. Loss after the drift (B)
        log.info("Step 6: Evaluating loss after drift...")
        loss_b = eval_loss(trainer, batch)
        log.info(" Loss (B) after drift: %.6f", loss_b)

        # Without a change between A and B the restore check would prove nothing.
        assert loss_a != loss_b, f"Loss didn't change after training: {loss_a} == {loss_b}"
        log.info(
            " Confirmed: loss changed after more training (%.6f -> %.6f)",
            loss_a,
            loss_b,
        )

        # 7. list_checkpoints
        log.info("Step 7: Listing checkpoints...")
        listed = trainer.list_checkpoints()
        log.info(" Found %d checkpoint(s):", len(listed))
        for entry in listed:
            log.info(" - id=%s path=%s type=%s", entry.id, entry.path, entry.checkpoint_type)
        assert len(listed) >= 1, "Expected at least 1 checkpoint"
        assert any(
            entry.id == saved.id for entry in listed
        ), f"Saved checkpoint {saved.id} not in list"

        # 8. Restore checkpoint
        # NOTE(review): the return value is ignored - confirm load_state blocks
        # until the restore has completed before the next evaluation runs.
        log.info("Step 8: Restoring model to checkpoint...")
        trainer.load_state(saved)
        log.info(" Model restored.")

        # 9. Loss after the restore (C) - expected to match (A)
        log.info("Step 9: Evaluating loss after restore...")
        loss_c = eval_loss(trainer, batch)
        log.info(" Loss (C) after restore: %.6f", loss_c)

        # 10. Compare
        log.info("=== Results ===")
        log.info("(A) Loss at checkpoint: %.6f", loss_a)
        log.info("(B) Loss after drift: %.6f", loss_b)
        log.info("(C) Loss after restore: %.6f", loss_c)

        # Allow small floating-point tolerance
        tolerance = 1e-4
        diff_ac = abs(loss_a - loss_c)
        diff_bc = abs(loss_b - loss_c)
        matches_checkpoint = diff_ac < tolerance
        moved_off_drift = diff_bc > tolerance

        # Fail only when C neither matches A nor can be told apart from B.
        if not (matches_checkpoint or moved_off_drift):
            log.error(
                "FAIL: Restored loss (C=%.6f) matches drifted loss (B=%.6f) — "
                "load_state did NOT restore the checkpoint!",
                loss_c,
                loss_b,
            )
            return 1

        if matches_checkpoint:
            log.info(
                "PASS: Restored loss (C) matches checkpoint loss (A) within tolerance (diff=%.8f)",
                diff_ac,
            )
        else:
            # Weaker pass: C is not within tolerance of A, but it did leave B.
            log.info(
                "PASS: Restored loss (C) differs from drifted loss (B) — "
                "load_state undid the drift. (diff A-C=%.6f, diff B-C=%.6f)",
                diff_ac,
                diff_bc,
            )

        log.info("=== Checkpoint E2E test passed ===")
        return 0
228+
229+
230+
if __name__ == "__main__":
    # Surface main()'s 0/1 verdict as the process exit status.
    sys.exit(main())

0 commit comments

Comments
 (0)