Skip to content
Merged

fix #153

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ export function createEventHandler(
return
}

if (event.type === "session.compacted") {
logger.info("Session compaction detected - updating state")
state.lastCompaction = Date.now()
}

if (event.type === "session.status" && event.properties.status.type === "idle") {
if (!config.strategies.onIdle.enabled) {
return
Expand Down
11 changes: 6 additions & 5 deletions lib/messages/prune.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { Logger } from "../logger"
import type { PluginConfig } from "../config"
import { loadPrompt } from "../prompt"
import { extractParameterKey, buildToolIdList } from "./utils"
import { getLastUserMessage } from "../shared-utils"
import { getLastUserMessage, isMessageCompacted } from "../shared-utils"
import { UserMessage } from "@opencode-ai/sdk"

const PRUNED_TOOL_INPUT_REPLACEMENT = '[Input removed to save context]'
Expand All @@ -17,7 +17,7 @@ const buildPrunableToolsList = (
messages: WithParts[],
): string => {
const lines: string[] = []
const toolIdList: string[] = buildToolIdList(messages)
const toolIdList: string[] = buildToolIdList(state, messages, logger)

state.toolParameters.forEach((toolParameterEntry, toolCallId) => {
if (state.prune.toolIds.includes(toolCallId)) {
Expand All @@ -26,9 +26,6 @@ const buildPrunableToolsList = (
if (config.strategies.pruneTool.protectedTools.includes(toolParameterEntry.tool)) {
return
}
if (toolParameterEntry.compacted) {
return
}
const numericId = toolIdList.indexOf(toolCallId)
const paramKey = extractParameterKey(toolParameterEntry.tool, toolParameterEntry.parameters)
const description = paramKey ? `${toolParameterEntry.tool}, ${paramKey}` : toolParameterEntry.tool
Expand Down Expand Up @@ -111,6 +108,10 @@ const pruneToolOutputs = (
messages: WithParts[]
): void => {
for (const msg of messages) {
if (isMessageCompacted(state, msg)) {
continue
}

for (const part of msg.parts) {
if (part.type !== 'tool') {
continue
Expand Down
13 changes: 11 additions & 2 deletions lib/messages/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import type { WithParts } from "../state"
import { Logger } from "../logger"
import { isMessageCompacted } from "../shared-utils"
import type { SessionState, WithParts } from "../state"

/**
* Extracts a human-readable key from tool metadata for display purposes.
Expand Down Expand Up @@ -71,9 +73,16 @@ export const extractParameterKey = (tool: string, parameters: any): string => {
return paramStr.substring(0, 50)
}

export function buildToolIdList(messages: WithParts[]): string[] {
export function buildToolIdList(
state: SessionState,
messages: WithParts[],
logger: Logger
): string[] {
const toolIds: string[] = []
for (const msg of messages) {
if (isMessageCompacted(state, msg)) {
continue
}
if (msg.parts) {
for (const part of msg.parts) {
if (part.type === 'tool' && part.callID && part.tool) {
Expand Down
20 changes: 19 additions & 1 deletion lib/shared-utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
import { WithParts } from "./state"
import { Logger } from "./logger"
import { SessionState, WithParts } from "./state"

export const isMessageCompacted = (
state: SessionState,
msg: WithParts
): boolean => {
return msg.info.time.created < state.lastCompaction
}

export const getLastUserMessage = (
messages: WithParts[]
Expand All @@ -11,3 +19,13 @@ export const getLastUserMessage = (
}
return null
}

export const checkForCompaction = (
state: SessionState,
messages: WithParts[],
logger: Logger
): void => {
for (const msg of messages) {

}
}
5 changes: 3 additions & 2 deletions lib/state/persistence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export interface PersistedSessionState {
prune: Prune
stats: SessionStats;
lastUpdated: string;
lastCompacted: number
}

const STORAGE_DIR = join(
Expand Down Expand Up @@ -55,6 +56,7 @@ export async function saveSessionState(
prune: sessionState.prune,
stats: sessionState.stats,
lastUpdated: new Date().toISOString(),
lastCompacted: sessionState.lastCompaction
};

const filePath = getSessionFilePath(sessionState.sessionId);
Expand Down Expand Up @@ -99,8 +101,7 @@ export async function loadSessionState(
}

logger.info("Loaded session state from disk", {
sessionId: sessionId,
totalTokensSaved: state.stats.totalPruneTokens
sessionId: sessionId
});

return state;
Expand Down
5 changes: 4 additions & 1 deletion lib/state/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ export function createSessionState(): SessionState {
},
toolParameters: new Map<string, ToolParameterEntry>(),
nudgeCounter: 0,
lastToolPrune: false
lastToolPrune: false,
lastCompaction: 0
}
}

Expand All @@ -58,6 +59,7 @@ export function resetSessionState(state: SessionState): void {
state.toolParameters.clear()
state.nudgeCounter = 0
state.lastToolPrune = false
state.lastCompaction = 0
}

export async function ensureSessionInitialized(
Expand Down Expand Up @@ -95,4 +97,5 @@ export async function ensureSessionInitialized(
pruneTokenCounter: persisted.stats?.pruneTokenCounter || 0,
totalPruneTokens: persisted.stats?.totalPruneTokens || 0,
}
state.lastCompaction = persisted.lastCompacted || 0
}
18 changes: 10 additions & 8 deletions lib/state/tool-cache.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { SessionState, ToolStatus, WithParts } from "./index"
import type { Logger } from "../logger"
import { PluginConfig } from "../config"
import { isMessageCompacted } from "../shared-utils"

const MAX_TOOL_CACHE_SIZE = 1000

Expand All @@ -19,10 +20,17 @@ export async function syncToolCache(
state.nudgeCounter = 0

for (const msg of messages) {
if (isMessageCompacted(state, msg)) {
continue
}

for (const part of msg.parts) {
if (part.type !== "tool" || !part.callID) {
continue
}
if (state.toolParameters.has(part.callID)) {
continue
}

if (part.tool === "prune") {
state.nudgeCounter = 0
Expand All @@ -31,25 +39,19 @@ export async function syncToolCache(
}
state.lastToolPrune = part.tool === "prune"

if (state.toolParameters.has(part.callID)) {
continue
}

state.toolParameters.set(
part.callID,
{
tool: part.tool,
parameters: part.state?.input ?? {},
status: part.state.status as ToolStatus | undefined,
error: part.state.status === "error" ? part.state.error : undefined,
compacted: part.state.status === "completed" && !!part.state.time.compacted,
}
)
logger.info("Cached tool id: " + part.callID)
}
}

// logger.info(`nudgeCounter=${state.nudgeCounter}, lastToolPrune=${state.lastToolPrune}`)

logger.info("Synced cache - size: " + state.toolParameters.size)
trimToolParametersCache(state)
} catch (error) {
logger.warn("Failed to sync tool parameters from OpenCode", {
Expand Down
2 changes: 1 addition & 1 deletion lib/state/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ export interface ToolParameterEntry {
parameters: any
status?: ToolStatus
error?: string
compacted?: boolean
}

export interface SessionStats {
Expand All @@ -32,4 +31,5 @@ export interface SessionState {
toolParameters: Map<string, ToolParameterEntry>
nudgeCounter: number
lastToolPrune: boolean
lastCompaction: number
}
4 changes: 2 additions & 2 deletions lib/strategies/deduplication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export const deduplicate = (
}

// Build list of all tool call IDs from messages (chronological order)
const allToolIds = buildToolIdList(messages)
const allToolIds = buildToolIdList(state, messages, logger)
if (allToolIds.length === 0) {
return
}
Expand Down Expand Up @@ -68,7 +68,7 @@ export const deduplicate = (
}
}

state.stats.totalPruneTokens += calculateTokensSaved(messages, newPruneIds)
state.stats.totalPruneTokens += calculateTokensSaved(state, messages, newPruneIds)

if (newPruneIds.length > 0) {
state.prune.toolIds.push(...newPruneIds)
Expand Down
9 changes: 7 additions & 2 deletions lib/strategies/on-idle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { selectModel, ModelInfo } from "../model-selector"
import { saveSessionState } from "../state/persistence"
import { sendUnifiedNotification } from "../ui/notification"
import { calculateTokensSaved, getCurrentParams } from "./utils"
import { isMessageCompacted } from "../shared-utils"

export interface OnIdleResult {
prunedCount: number
Expand All @@ -18,6 +19,7 @@ export interface OnIdleResult {
* Parse messages to extract tool information.
*/
function parseMessages(
state: SessionState,
messages: WithParts[],
toolParametersCache: Map<string, ToolParameterEntry>
): {
Expand All @@ -28,6 +30,9 @@ function parseMessages(
const toolMetadata = new Map<string, ToolParameterEntry>()

for (const msg of messages) {
if (isMessageCompacted(state, msg)) {
continue
}
if (msg.parts) {
for (const part of msg.parts) {
if (part.type === "tool" && part.callID) {
Expand Down Expand Up @@ -224,7 +229,7 @@ export async function runOnIdle(
}

const currentParams = getCurrentParams(messages, logger)
const { toolCallIds, toolMetadata } = parseMessages(messages, state.toolParameters)
const { toolCallIds, toolMetadata } = parseMessages(state, messages, state.toolParameters)

const alreadyPrunedIds = state.prune.toolIds
const unprunedToolCallIds = toolCallIds.filter(id => !alreadyPrunedIds.includes(id))
Expand Down Expand Up @@ -273,7 +278,7 @@ export async function runOnIdle(
const allPrunedIds = [...new Set([...alreadyPrunedIds, ...newlyPrunedIds])]
state.prune.toolIds = allPrunedIds

state.stats.pruneTokenCounter += calculateTokensSaved(messages, newlyPrunedIds)
state.stats.pruneTokenCounter += calculateTokensSaved(state, messages, newlyPrunedIds)

// Build tool metadata map for notification
const prunedToolMetadata = new Map<string, ToolParameterEntry>()
Expand Down
7 changes: 5 additions & 2 deletions lib/strategies/prune-tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ export function createPruneTool(
const { client, state, logger, config, workingDirectory } = ctx
const sessionId = toolCtx.sessionID

logger.info("Prune tool invoked")
logger.info(JSON.stringify(args))

if (!args.ids || args.ids.length === 0) {
logger.debug("Prune tool called but args.ids is empty or undefined: " + JSON.stringify(args))
return "No IDs provided. Check the <prunable-tools> list for available IDs to prune."
Expand Down Expand Up @@ -72,7 +75,7 @@ export function createPruneTool(
const messages: WithParts[] = messagesResponse.data || messagesResponse

const currentParams = getCurrentParams(messages, logger)
const toolIdList: string[] = buildToolIdList(messages)
const toolIdList: string[] = buildToolIdList(state, messages, logger)

// Validate that all numeric IDs are within bounds
if (numericToolIds.some(id => id < 0 || id >= toolIdList.length)) {
Expand Down Expand Up @@ -102,7 +105,7 @@ export function createPruneTool(
}
}

state.stats.pruneTokenCounter += calculateTokensSaved(messages, pruneToolIds)
state.stats.pruneTokenCounter += calculateTokensSaved(state, messages, pruneToolIds)

await sendUnifiedNotification(
client,
Expand Down
8 changes: 6 additions & 2 deletions lib/strategies/utils.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { WithParts } from "../state"
import { SessionState, WithParts } from "../state"
import { UserMessage } from "@opencode-ai/sdk"
import { Logger } from "../logger"
import { encode } from 'gpt-tokenizer'
import { getLastUserMessage } from "../shared-utils"
import { getLastUserMessage, isMessageCompacted } from "../shared-utils"

export function getCurrentParams(
messages: WithParts[],
Expand Down Expand Up @@ -40,12 +40,16 @@ function estimateTokensBatch(texts: string[]): number[] {
* TODO: Make it count message content that are not tool outputs. Currently it ONLY covers tool outputs and errors
*/
export const calculateTokensSaved = (
state: SessionState,
messages: WithParts[],
pruneToolIds: string[]
): number => {
try {
const contents: string[] = []
for (const msg of messages) {
if (isMessageCompacted(state, msg)) {
continue
}
for (const part of msg.parts) {
if (part.type !== 'tool' || !pruneToolIds.includes(part.callID)) {
continue
Expand Down