diff --git a/src/delm/config.py b/src/delm/config.py
index c4a0d3c..9ef62dd 100644
--- a/src/delm/config.py
+++ b/src/delm/config.py
@@ -579,38 +579,42 @@ def from_dict(cls, data: Dict[str, Any]) -> "DELMConfig":
         if data is None:
             data = {}
 
-        # Check if this is nested format (from to_dict())
         return cls(
             schema=data["schema"],
-            provider=data["provider"],
-            model=data["model"],
-            base_url=data["base_url"],
-            mode=data["mode"],
-            temperature=data["temperature"],
-            prompt_template=data["prompt_template"],
-            system_prompt=data["system_prompt"],
-            batch_size=data["batch_size"],
-            max_workers=data["max_workers"],
-            max_retries=data["max_retries"],
-            base_delay=data["base_delay"],
-            rate_limit_tokens=data["rate_limit_tokens"],
-            rate_limit_requests=data["rate_limit_requests"],
+            provider=data.get("provider", "openai"),
+            model=data.get("model", "gpt-4o-mini"),
+            base_url=data.get("base_url"),
+            mode=data.get("mode"),
+            temperature=data.get("temperature", 0.0),
+            prompt_template=data.get(
+                "prompt_template",
+                "Extract the following information from the text:\n\n{variables}\n\nText to analyze:\n{text}",
+            ),
+            system_prompt=data.get(
+                "system_prompt", "You are a precise data-extraction assistant."
+            ),
+            batch_size=data.get("batch_size", 10),
+            max_workers=data.get("max_workers", 1),
+            max_retries=data.get("max_retries", 3),
+            base_delay=data.get("base_delay", 1.0),
+            rate_limit_tokens=data.get("rate_limit_tokens"),
+            rate_limit_requests=data.get("rate_limit_requests"),
             rate_limit_period_seconds=data.get("rate_limit_period_seconds", 60.0),
-            track_cost=data["track_cost"],
-            max_budget=data["max_budget"],
-            model_input_cost_per_1M_tokens=data["model_input_cost_per_1M_tokens"],
-            model_output_cost_per_1M_tokens=data["model_output_cost_per_1M_tokens"],
+            track_cost=data.get("track_cost", True),
+            max_budget=data.get("max_budget"),
+            model_input_cost_per_1M_tokens=data.get("model_input_cost_per_1M_tokens"),
+            model_output_cost_per_1M_tokens=data.get("model_output_cost_per_1M_tokens"),
             max_completion_tokens=data.get("max_completion_tokens", 4096),
             api_kwargs=data.get("api_kwargs", {}),
-            target_column=data["target_column"],
-            drop_target_column=data["drop_target_column"],
-            splitting_strategy=data["splitting_strategy"],
-            relevance_scorer=data["relevance_scorer"],
-            score_filter=data["score_filter"],
-            cache_backend=data["cache_backend"],
-            cache_path=data["cache_path"],
-            cache_max_size_mb=data["cache_max_size_mb"],
-            cache_synchronous=data["cache_synchronous"],
+            target_column=data.get("target_column", "text"),
+            drop_target_column=data.get("drop_target_column", False),
+            splitting_strategy=data.get("splitting_strategy"),
+            relevance_scorer=data.get("relevance_scorer"),
+            score_filter=data.get("score_filter"),
+            cache_backend=data.get("cache_backend", "sqlite"),
+            cache_path=data.get("cache_path", ".delm/cache"),
+            cache_max_size_mb=data.get("cache_max_size_mb", 512),
+            cache_synchronous=data.get("cache_synchronous", "normal"),
         )
 
     @classmethod
diff --git a/tests/unit/api_kwargs/test_api_kwargs.py b/tests/unit/api_kwargs/test_api_kwargs.py
index 5bcb4a4..a7a1a2c 100644
--- a/tests/unit/api_kwargs/test_api_kwargs.py
+++ b/tests/unit/api_kwargs/test_api_kwargs.py
@@ -51,6 +51,21 @@ def test_api_kwargs_serialization_round_trip(self, simple_schema):
         restored = DELMConfig.from_dict(config_dict)
         assert restored.llm_extraction_cfg.api_kwargs == kwargs
 
+    def test_from_dict_uses_defaults_for_missing_optional_fields(self, simple_schema):
+        config = DELMConfig(schema=simple_schema)
+        config_dict = config.to_dict()
+
+        del config_dict["base_url"]
+        del config_dict["mode"]
+        del config_dict["api_kwargs"]
+        del config_dict["rate_limit_period_seconds"]
+
+        restored = DELMConfig.from_dict(config_dict)
+        assert restored.llm_extraction_cfg.base_url is None
+        assert restored.llm_extraction_cfg.mode is None
+        assert restored.llm_extraction_cfg.api_kwargs == {}
+        assert restored.llm_extraction_cfg.rate_limit_period_seconds == 60.0
+
     def test_api_kwargs_validation_rejects_non_dict(self, simple_schema):
         config = DELMConfig(schema=simple_schema, api_kwargs={"valid": True})
         config.llm_extraction_cfg.api_kwargs = "not_a_dict"