Skip to content

Commit 588a8eb

Browse files
committed
ref: caching, session
1 parent 6c2c6c7 commit 588a8eb

23 files changed

Lines changed: 111 additions & 74 deletions

File tree

fedotllm/agents/automl/automl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
self, config: AppConfig, dataset_path: str | Path, workspace: str | Path
4444
):
4545
self.config = config
46-
self.inference = AIInference(config.llm)
46+
self.inference = AIInference(config.llm, config.session_id)
4747
self.dataset = Dataset.from_path(dataset_path)
4848
self.workspace = Path(workspace)
4949

fedotllm/agents/automl/nodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,19 @@ def if_bug(state: AutoMLAgentState, app_config: AppConfig):
216216
def fix_solution(state: AutoMLAgentState, inference: AIInference, dataset: Dataset):
217217
logger.info("Running fix solution")
218218

219+
stdout_lines = state["observation"].stdout.splitlines()
220+
if len(stdout_lines) > 20:
221+
stdout = "\n".join(stdout_lines[:10] + ["..."] + stdout_lines[-10:])
222+
else:
223+
stdout = state["observation"].stdout
224+
219225
fix_prompt = prompts.automl.fix_solution_prompt(
220226
reflection=state["reflection"],
221227
dataset_path=str(dataset.path.absolute()),
222228
code_recent_solution=state["raw_code"],
223229
msg=state["observation"].msg,
224230
stderr=state["observation"].stderr,
225-
stdout=state["observation"].stdout,
231+
stdout=stdout,
226232
)
227233

228234
fixed_solution = inference.query(fix_prompt)

fedotllm/agents/researcher/researcher.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
rewrite_question,
1515
)
1616
from fedotllm.agents.researcher.state import ResearcherAgentState
17+
from fedotllm.configs.schema import AppConfig
1718
from fedotllm.llm import AIInference, LiteLLMEmbeddings
1819

1920
RETRIEVE = "retrieve"
@@ -24,9 +25,9 @@
2425

2526

2627
class ResearcherAgent(Agent):
27-
def __init__(self, inference: AIInference, embeddings: LiteLLMEmbeddings):
28-
self.inference = inference
29-
self.embeddings = embeddings
28+
def __init__(self, config: AppConfig):
29+
self.inference = AIInference(config.llm, config.session_id)
30+
self.embeddings = LiteLLMEmbeddings(config.embeddings)
3031

3132
def create_graph(self):
3233
workflow = StateGraph(ResearcherAgentState)

fedotllm/agents/supervisor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import BaseModel, Field
1010

1111
from fedotllm.agents.base import Agent, FedotLLMAgentState
12+
from fedotllm.configs.schema import AppConfig
1213
from fedotllm.llm import AIInference
1314
from fedotllm.prompts.supervisor import choose_next_prompt
1415

@@ -26,11 +27,11 @@ class SupervisorState(FedotLLMAgentState):
2627
class SupervisorAgent(Agent):
2728
def __init__(
2829
self,
29-
inference: AIInference,
30+
config: AppConfig,
3031
automl_agent: Runnable,
3132
researcher_agent: Runnable,
3233
):
33-
self.inference = inference
34+
self.inference = AIInference(config.llm, config.session_id)
3435
self.researcher_agent = researcher_agent
3536
self.automl_agent = automl_agent
3637

@@ -50,7 +51,7 @@ def finish_execution(state: SupervisorState):
5051
workflow.add_edge("researcher", "choose_next")
5152
workflow.add_edge("automl", "finish")
5253
workflow.add_edge("finish", END)
53-
return workflow.compile().with_config(config={"run_name": "SupervisorAgent"})
54+
return workflow.compile().with_config(run_name=SupervisorAgent)
5455

5556

5657
class ChooseNext(BaseModel):
@@ -72,6 +73,7 @@ def router_node(
7273
"""
7374

7475
messages = convert_to_openai_messages(state["messages"])
76+
messages = [messages] if isinstance(messages, dict) else messages
7577
messages.append({"role": "user", "content": choose_next_prompt()})
7678

7779
response = inference.query(messages)

fedotllm/configs/default.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ llm:
1212
model_name: gemini-2.0-flash
1313
base_url: https://generativelanguage.googleapis.com/v1beta/openai/
1414
api_key: ${oc.env:FEDOTLLM_LLM_API_KEY}
15+
caching:
16+
enabled: true
1517
extra_headers:
1618
X-Title: FEDOT.LLM
1719
embeddings:

fedotllm/configs/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import re
3-
from importlib.resources import files
43
from pathlib import Path
54
from typing import List, Optional, Type, TypeVar
65

@@ -9,6 +8,7 @@
98
from pydantic import BaseModel
109

1110
from fedotllm.configs.schema import AppConfig
11+
from fedotllm.constants import PACKAGE_PATH
1212

1313

1414
def _get_default_config_path(
@@ -27,7 +27,7 @@ def _get_default_config_path(
2727
ValueError: If the config file is not found.
2828
"""
2929
try:
30-
config_path = Path(files("fedotllm") / "configs" / f"{presets}.yaml")
30+
config_path = Path(PACKAGE_PATH) / "configs" / f"{presets}.yaml"
3131

3232
if not config_path.exists():
3333
raise ValueError(

fedotllm/configs/ollama.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
llm:
2+
provider: ollama
3+
model_name: "devstral:latest"
4+
base_url:
5+
api_key: "ollama"

fedotllm/configs/schema.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from pathlib import Path
12
from typing import Any, Dict, Optional
23

34
from pydantic import BaseModel, Field
45

6+
from fedotllm.constants import PACKAGE_PATH
7+
58

69
class TemplatesConfig(BaseModel):
710
code: str
@@ -16,11 +19,17 @@ class AutoMLConfig(BaseModel):
1619
predictor_init_kwargs: dict = Field(default_factory=dict)
1720

1821

22+
class CachingConfig(BaseModel):
23+
enabled: bool = True
24+
dir_path: str = Field(default=str(Path(PACKAGE_PATH) / "cache"))
25+
26+
1927
class LLMConfig(BaseModel):
2028
provider: str = "openai"
2129
model_name: str = "gpt-4o"
2230
base_url: Optional[str] = None
2331
api_key: Optional[str] = None
32+
caching: CachingConfig = Field(default_factory=CachingConfig)
2433
extra_headers: Dict[str, Any] = {}
2534
completion_params: Dict[str, Any] = {}
2635

@@ -45,3 +54,4 @@ class AppConfig(BaseModel):
4554
embeddings: EmbeddingsConfig = Field(default_factory=EmbeddingsConfig)
4655
langfuse: LangfuseConfig = Field(default_factory=LangfuseConfig)
4756
automl: AutoMLConfig = Field(default_factory=AutoMLConfig)
57+
session_id: Optional[str] = Field(default=None)

fedotllm/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from importlib.resources import files
2+
from pathlib import Path
3+
4+
PACKAGE_PATH = str(Path(files("fedotllm")))
5+
16
# File formats
27
CSV_SUFFIXES = [".csv"]
38
PARQUET_SUFFIXES = [".parquet", ".pq"]

fedotllm/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def dataset_eda(self):
126126
train_split = self.get_train_split()
127127
df = train_split.data
128128
eda = ""
129-
if df.shape[1] <= 10:
129+
if df.shape[1] <= 20:
130130
eda += "\n===== 1. BASIC INFO =====\n"
131131

132132
buf = io.StringIO()

0 commit comments

Comments
 (0)