forked from Opencode-DCP/opencode-dynamic-context-pruning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprune.ts
More file actions
134 lines (120 loc) · 4.49 KB
/
prune.ts
File metadata and controls
134 lines (120 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import type { SessionState, WithParts } from "../state"
import type { Logger } from "../logger"
import type { PluginConfig } from "../config"
import { getLastUserMessage, extractParameterKey, buildToolIdList } from "./utils"
import { loadPrompt } from "../prompt"
// Text written over the `output` of completed tool calls whose call IDs are in `state.prune.toolIds` (see pruneToolOutputs).
const PRUNED_TOOL_OUTPUT_REPLACEMENT = '[Output removed to save context - information superseded or no longer needed]'
// Prompt texts loaded once, when this module is evaluated (presumably the contents of the named prompt files — see loadPrompt).
// insertPruneToolContext appends NUDGE_STRING when `state.nudgeCounter` reaches the nudge frequency, and
// RECALL_REMINDER_STRING when recall is enabled and `state.recallCounter` reaches the recall frequency.
const NUDGE_STRING = loadPrompt("nudge")
const RECALL_REMINDER_STRING = loadPrompt("recall-reminder")
/**
 * Renders the `<prunable-tools>` context block that lists tool calls the model may still prune.
 *
 * Each entry is `<numericId>: <tool>[, <paramKey>]`. `numericId` is the position of the call ID in
 * `buildToolIdList(messages)`, using the first occurrence. An entry is skipped when:
 * - its call ID is already in `state.prune.toolIds`;
 * - its tool is listed in `config.strategies.pruneTool.protectedTools`;
 * - it is marked `compacted`;
 * - its call ID does not occur in `messages`. Such entries were previously listed as `-1`.
 *
 * @returns The formatted block, or `""` when nothing is prunable.
 */
const buildPrunableToolsList = (
    state: SessionState,
    config: PluginConfig,
    logger: Logger,
    messages: WithParts[],
): string => {
    const toolIdList: string[] = buildToolIdList(messages)

    // Record only the first position of each call ID so numbering matches the previous
    // `indexOf` lookup, while making each lookup O(1) instead of a linear scan.
    const numericIdByCallId = new Map<string, number>()
    toolIdList.forEach((callId, index) => {
        if (!numericIdByCallId.has(callId)) {
            numericIdByCallId.set(callId, index)
        }
    })
    // Built once instead of scanning the pruned-ID array for every tracked tool call.
    const prunedCallIds = new Set(state.prune.toolIds)
    const protectedTools = config.strategies.pruneTool.protectedTools

    const lines: string[] = []
    state.toolParameters.forEach((toolParameterEntry, toolCallId) => {
        if (prunedCallIds.has(toolCallId)) {
            return
        }
        if (protectedTools.includes(toolParameterEntry.tool)) {
            return
        }
        if (toolParameterEntry.compacted) {
            return
        }
        const numericId = numericIdByCallId.get(toolCallId)
        if (numericId === undefined) {
            // Tracked in state but absent from the current messages: there is no valid numeric ID
            // the model could use to refer to it, so do not advertise it.
            logger.debug(`Skipping tool not present in messages - Tool: ${toolParameterEntry.tool}, Call ID: ${toolCallId}`)
            return
        }
        const paramKey = extractParameterKey(toolParameterEntry.tool, toolParameterEntry.parameters)
        const description = paramKey ? `${toolParameterEntry.tool}, ${paramKey}` : toolParameterEntry.tool
        lines.push(`${numericId}: ${description}`)
        logger.debug(`Prunable tool found - ID: ${numericId}, Tool: ${toolParameterEntry.tool}, Call ID: ${toolCallId}`)
    })

    if (lines.length === 0) {
        return ""
    }
    return `<prunable-tools>\nThe following tools have been invoked and are available for pruning. This list does not mandate immediate action. Consider your current goals and the resources you need before discarding valuable tool outputs. Keep the context free of noise.\n${lines.join('\n')}\n</prunable-tools>`
}
/**
 * Appends a synthetic user message advertising the tool calls that are currently prunable.
 *
 * Nothing is appended in any of these cases:
 * - the prune tool strategy is disabled;
 * - there is no last user message;
 * - no tool call is prunable.
 *
 * The message text is the prunable-tools block, optionally followed by extra prompt text:
 * - The nudge prompt is added when `state.nudgeCounter` has reached the nudge frequency.
 *   `state.nudgeCounter` is only read here, never reset.
 * - The recall reminder is added when recall is enabled and `state.recallCounter` has reached
 *   the recall frequency. In that case `state.recallCounter` is reset to 0.
 *
 * Side effects: pushes onto `messages` and may reset `state.recallCounter`.
 */
export const insertPruneToolContext = (
    state: SessionState,
    config: PluginConfig,
    logger: Logger,
    messages: WithParts[]
): void => {
    const pruneToolConfig = config.strategies.pruneTool
    if (!pruneToolConfig.enabled) return

    const anchor = getLastUserMessage(messages)
    if (!anchor) return
    if (anchor.info.role !== 'user') return

    const toolList = buildPrunableToolsList(state, config, logger, messages)
    if (!toolList) return

    // Optional prompts are appended to the list text, each on its own line.
    let contextText = toolList
    if (state.nudgeCounter >= pruneToolConfig.nudge.frequency) {
        logger.info("Inserting prune nudge message")
        contextText += "\n" + NUDGE_STRING
    }
    if (pruneToolConfig.recall.enabled && state.recallCounter >= pruneToolConfig.recall.frequency) {
        logger.info("Inserting recall reminder")
        contextText += "\n" + RECALL_REMINDER_STRING
        state.recallCounter = 0
    }

    // Fixed placeholder IDs; the part's messageID points back at the message ID.
    const syntheticMessageId = "msg_01234567890123456789012345"
    const { sessionID, agent, model } = anchor.info
    const injected: WithParts = {
        info: {
            id: syntheticMessageId,
            sessionID,
            role: "user",
            time: { created: Date.now() },
            agent: agent || "build",
            model: {
                providerID: model.providerID,
                modelID: model.modelID
            }
        },
        parts: [
            {
                id: "prt_01234567890123456789012345",
                sessionID,
                messageID: syntheticMessageId,
                type: "text",
                text: contextText,
            }
        ]
    }
    messages.push(injected)
}
/**
 * Applies every pruning strategy, in order, to `messages` (mutated in place).
 *
 * Only tool-output replacement exists today. `config` is part of the public signature for
 * strategies still to come and is currently unused.
 */
export const prune = (
    state: SessionState,
    logger: Logger,
    config: PluginConfig,
    messages: WithParts[]
): void => {
    // Further prune strategies get appended to this list.
    const strategies = [pruneToolOutputs]
    for (const strategy of strategies) {
        strategy(state, logger, messages)
    }
}
/**
 * Replaces the `output` of every completed tool call whose call ID is in `state.prune.toolIds`
 * with PRUNED_TOOL_OUTPUT_REPLACEMENT. Parts inside `messages` are mutated in place.
 *
 * Tool parts in any other status, including errored ones, are left untouched. Non-tool parts
 * are ignored. `logger` is accepted for signature parity and is currently unused.
 */
const pruneToolOutputs = (
    state: SessionState,
    logger: Logger,
    messages: WithParts[]
): void => {
    // Built once instead of scanning the pruned-ID array for every part (was O(parts × prunedIds)).
    const prunedCallIds = new Set(state.prune.toolIds)
    if (prunedCallIds.size === 0) {
        return
    }
    for (const msg of messages) {
        for (const part of msg.parts) {
            if (part.type !== 'tool' || !prunedCallIds.has(part.callID)) {
                continue
            }
            if (part.state.status === 'completed') {
                part.state.output = PRUNED_TOOL_OUTPUT_REPLACEMENT
            }
            // Errored calls are not rewritten; replacing the error text is currently disabled:
            // if (part.state.status === 'error') {
            //     part.state.error = PRUNED_TOOL_OUTPUT_REPLACEMENT
            // }
        }
    }
}