Skip to content

Commit 2f23b0f

Browse files
committed
fix(training-operator)
harden HuggingFace training_parameters parsing Signed-off-by: Ayush-kathil <kathilshiva@gmail.com>
1 parent e4705d7 commit 2f23b0f

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import json
5+
import logging
6+
from typing import Any, Optional
7+
8+
try:
9+
from transformers import TrainingArguments
10+
except ImportError: # pragma: no cover - exercised only when transformers is absent.
11+
class TrainingArguments: # type: ignore[no-redef]
12+
def __init__(self, *args: Any, **kwargs: Any) -> None:
13+
raise ImportError(
14+
"transformers is required to construct HuggingFace TrainingArguments."
15+
)
16+
17+
18+
logger = logging.getLogger(__name__)
19+
20+
DEFAULT_OUTPUT_DIR = "./output"
21+
22+
23+
def parse_training_args(raw: Optional[str]) -> dict[str, Any]:
24+
"""Parse a JSON string into a TrainingArguments configuration."""
25+
26+
if raw is None:
27+
return {}
28+
29+
if not isinstance(raw, str):
30+
raise ValueError(
31+
"training_parameters must be a JSON string or None; got "
32+
f"{type(raw).__name__}."
33+
)
34+
35+
normalized = raw.strip()
36+
if not normalized:
37+
return {}
38+
39+
try:
40+
parsed = json.loads(normalized)
41+
except json.JSONDecodeError as exc:
42+
raise ValueError(
43+
"Invalid JSON in training_parameters. Provide a JSON object string, for "
44+
f"example '{{\"output_dir\": \"./output\"}}'. Received: {raw!r}. "
45+
f"JSON error: {exc.msg} at line {exc.lineno}, column {exc.colno}."
46+
) from exc
47+
48+
if not isinstance(parsed, dict):
49+
raise ValueError(
50+
"training_parameters must decode to a JSON object. Received "
51+
f"{type(parsed).__name__}: {parsed!r}."
52+
)
53+
54+
invalid_keys = [key for key in parsed.keys() if not isinstance(key, str) or not key.strip()]
55+
if invalid_keys:
56+
raise ValueError(
57+
"training_parameters contains invalid keys. JSON object keys must be non-empty "
58+
f"strings. Invalid keys: {invalid_keys!r}."
59+
)
60+
61+
return parsed
62+
63+
64+
def build_training_arguments(raw: Optional[str]) -> TrainingArguments:
65+
logger.info("Raw training_parameters payload: %r", raw)
66+
parsed_config = parse_training_args(raw)
67+
68+
if not parsed_config:
69+
logger.info(
70+
"training_parameters is empty or missing; using default TrainingArguments with output_dir=%s",
71+
DEFAULT_OUTPUT_DIR,
72+
)
73+
return TrainingArguments(output_dir=DEFAULT_OUTPUT_DIR)
74+
75+
logger.info("Parsed training_parameters config: %s", json.dumps(parsed_config, sort_keys=True))
76+
try:
77+
return TrainingArguments(**parsed_config)
78+
except Exception as exc:
79+
logger.error(
80+
"Failed to create TrainingArguments from parsed training_parameters: %s",
81+
json.dumps(parsed_config, sort_keys=True),
82+
exc_info=True,
83+
)
84+
raise ValueError(
85+
"Failed to initialize TrainingArguments from training_parameters. "
86+
"Check the JSON keys and values, and ensure they match the HuggingFace "
87+
f"TrainingArguments signature. Parsed config: {json.dumps(parsed_config, sort_keys=True)}"
88+
) from exc
89+
90+
91+
def _build_parser() -> argparse.ArgumentParser:
92+
parser = argparse.ArgumentParser(description="Run a HuggingFace training job.")
93+
parser.add_argument(
94+
"--training_parameters",
95+
type=str,
96+
default="{}",
97+
help="JSON object used to initialize HuggingFace TrainingArguments.",
98+
)
99+
return parser
100+
101+
102+
def main() -> None:
103+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
104+
parser = _build_parser()
105+
args = parser.parse_args()
106+
107+
training_args = build_training_arguments(args.training_parameters)
108+
logger.info("TrainingArguments initialized successfully: %s", training_args)
109+
110+
# Replace this with the actual training workflow used by the example.
111+
logger.info("Trainer entrypoint completed parsing and initialization only.")
112+
113+
114+
if __name__ == "__main__":
115+
main()
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from __future__ import annotations
2+
3+
import importlib.util
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
9+
SCRIPT_PATH = (
10+
Path(__file__).resolve().parents[4]
11+
/ "examples"
12+
/ "v1beta1"
13+
/ "kubeflow-training-operator"
14+
/ "hf_llm_training.py"
15+
)
16+
17+
18+
class DummyTrainingArguments:
19+
def __init__(self, **kwargs):
20+
self.kwargs = kwargs
21+
22+
23+
def load_module():
24+
spec = importlib.util.spec_from_file_location("hf_llm_training", SCRIPT_PATH)
25+
module = importlib.util.module_from_spec(spec)
26+
assert spec.loader is not None
27+
spec.loader.exec_module(module)
28+
return module
29+
30+
31+
def test_parse_training_args_empty_string_returns_empty_dict():
32+
module = load_module()
33+
34+
assert module.parse_training_args("") == {}
35+
36+
37+
def test_parse_training_args_none_returns_empty_dict():
38+
module = load_module()
39+
40+
assert module.parse_training_args(None) == {}
41+
42+
43+
def test_parse_training_args_whitespace_returns_empty_dict():
44+
module = load_module()
45+
46+
assert module.parse_training_args(" \n\t ") == {}
47+
48+
49+
def test_parse_training_args_valid_json_returns_dict():
50+
module = load_module()
51+
52+
assert module.parse_training_args('{"output_dir": "./output", "learning_rate": 0.0001}') == {
53+
"output_dir": "./output",
54+
"learning_rate": 0.0001,
55+
}
56+
57+
58+
def test_parse_training_args_invalid_json_raises_value_error():
59+
module = load_module()
60+
61+
with pytest.raises(ValueError, match="Invalid JSON in training_parameters"):
62+
module.parse_training_args("{invalid-json")
63+
64+
65+
def test_parse_training_args_malformed_keys_raises_value_error():
66+
module = load_module()
67+
68+
with pytest.raises(ValueError, match="invalid keys"):
69+
module.parse_training_args('{"": 1, " ": 2}')
70+
71+
72+
def test_build_training_arguments_uses_default_when_empty(monkeypatch):
73+
module = load_module()
74+
monkeypatch.setattr(module, "TrainingArguments", DummyTrainingArguments)
75+
76+
training_args = module.build_training_arguments("")
77+
78+
assert training_args.kwargs == {"output_dir": "./output"}
79+
80+
81+
def test_build_training_arguments_passes_valid_config(monkeypatch):
82+
module = load_module()
83+
monkeypatch.setattr(module, "TrainingArguments", DummyTrainingArguments)
84+
85+
training_args = module.build_training_arguments('{"output_dir": "./tmp", "num_train_epochs": 3}')
86+
87+
assert training_args.kwargs == {"output_dir": "./tmp", "num_train_epochs": 3}

0 commit comments

Comments
 (0)