Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions src/delm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,38 +579,42 @@ def from_dict(cls, data: Dict[str, Any]) -> "DELMConfig":
if data is None:
data = {}

# Check if this is nested format (from to_dict())
return cls(
schema=data["schema"],
provider=data["provider"],
model=data["model"],
base_url=data["base_url"],
mode=data["mode"],
temperature=data["temperature"],
prompt_template=data["prompt_template"],
system_prompt=data["system_prompt"],
batch_size=data["batch_size"],
max_workers=data["max_workers"],
max_retries=data["max_retries"],
base_delay=data["base_delay"],
rate_limit_tokens=data["rate_limit_tokens"],
rate_limit_requests=data["rate_limit_requests"],
provider=data.get("provider", "openai"),
model=data.get("model", "gpt-4o-mini"),
base_url=data.get("base_url"),
mode=data.get("mode"),
temperature=data.get("temperature", 0.0),
prompt_template=data.get(
"prompt_template",
"Extract the following information from the text:\n\n{variables}\n\nText to analyze:\n{text}",
),
system_prompt=data.get(
"system_prompt", "You are a precise data-extraction assistant."
),
batch_size=data.get("batch_size", 10),
max_workers=data.get("max_workers", 1),
max_retries=data.get("max_retries", 3),
base_delay=data.get("base_delay", 1.0),
rate_limit_tokens=data.get("rate_limit_tokens"),
rate_limit_requests=data.get("rate_limit_requests"),
rate_limit_period_seconds=data.get("rate_limit_period_seconds", 60.0),
track_cost=data["track_cost"],
max_budget=data["max_budget"],
model_input_cost_per_1M_tokens=data["model_input_cost_per_1M_tokens"],
model_output_cost_per_1M_tokens=data["model_output_cost_per_1M_tokens"],
track_cost=data.get("track_cost", True),
max_budget=data.get("max_budget"),
model_input_cost_per_1M_tokens=data.get("model_input_cost_per_1M_tokens"),
model_output_cost_per_1M_tokens=data.get("model_output_cost_per_1M_tokens"),
max_completion_tokens=data.get("max_completion_tokens", 4096),
api_kwargs=data.get("api_kwargs", {}),
target_column=data["target_column"],
drop_target_column=data["drop_target_column"],
splitting_strategy=data["splitting_strategy"],
relevance_scorer=data["relevance_scorer"],
score_filter=data["score_filter"],
cache_backend=data["cache_backend"],
cache_path=data["cache_path"],
cache_max_size_mb=data["cache_max_size_mb"],
cache_synchronous=data["cache_synchronous"],
target_column=data.get("target_column", "text"),
drop_target_column=data.get("drop_target_column", False),
splitting_strategy=data.get("splitting_strategy"),
relevance_scorer=data.get("relevance_scorer"),
score_filter=data.get("score_filter"),
cache_backend=data.get("cache_backend", "sqlite"),
cache_path=data.get("cache_path", ".delm/cache"),
cache_max_size_mb=data.get("cache_max_size_mb", 512),
cache_synchronous=data.get("cache_synchronous", "normal"),
)

@classmethod
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/api_kwargs/test_api_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ def test_api_kwargs_serialization_round_trip(self, simple_schema):
restored = DELMConfig.from_dict(config_dict)
assert restored.llm_extraction_cfg.api_kwargs == kwargs

def test_from_dict_uses_defaults_for_missing_optional_fields(self, simple_schema):
    """Check that ``from_dict`` fills in defaults for absent optional keys.

    A config is serialized with ``to_dict()``, four optional keys are
    removed, and the dict is loaded back. The restored LLM extraction
    settings are expected to be ``None`` for ``base_url`` and ``mode``,
    an empty dict for ``api_kwargs``, and ``60.0`` for
    ``rate_limit_period_seconds``.
    """
    serialized = DELMConfig(schema=simple_schema).to_dict()

    # Each key must be present after to_dict(); `del` raises KeyError otherwise,
    # which keeps the test honest about the serialized shape.
    for optional_key in (
        "base_url",
        "mode",
        "api_kwargs",
        "rate_limit_period_seconds",
    ):
        del serialized[optional_key]

    llm_cfg = DELMConfig.from_dict(serialized).llm_extraction_cfg
    assert llm_cfg.base_url is None
    assert llm_cfg.mode is None
    assert llm_cfg.api_kwargs == {}
    assert llm_cfg.rate_limit_period_seconds == 60.0

def test_api_kwargs_validation_rejects_non_dict(self, simple_schema):
config = DELMConfig(schema=simple_schema, api_kwargs={"valid": True})
config.llm_extraction_cfg.api_kwargs = "not_a_dict"
Expand Down