diff --git a/lib/hooks.ts b/lib/hooks.ts index a48183e..72fda69 100644 --- a/lib/hooks.ts +++ b/lib/hooks.ts @@ -14,11 +14,12 @@ export function createChatMessageTransformHandler( logger: Logger, config: PluginConfig ) { - return async( + return async ( input: {}, output: { messages: WithParts[] } ) => { - checkSession(client, state, logger, output.messages); + await checkSession(client, state, logger, output.messages) + if (state.isSubAgent) { return } diff --git a/lib/state/state.ts b/lib/state/state.ts index fc69883..91e3f92 100644 --- a/lib/state/state.ts +++ b/lib/state/state.ts @@ -4,12 +4,12 @@ import { loadSessionState } from "./persistence" import { getLastUserMessage } from "../messages/utils" import { isSubAgentSession } from "../utils" -export const checkSession = ( +export const checkSession = async ( client: any, state: SessionState, logger: Logger, messages: WithParts[] -) => { +): Promise => { const lastUserMessage = getLastUserMessage(messages) if (!lastUserMessage) { @@ -20,14 +20,11 @@ export const checkSession = ( if (state.sessionId === null || state.sessionId !== lastSessionId) { logger.info(`Session changed: ${state.sessionId} -> ${lastSessionId}`) - ensureSessionInitialized( - client, - state, - lastSessionId, - logger - ).catch((err) => { + try { + await ensureSessionInitialized(client, state, lastSessionId, logger) + } catch (err: any) { logger.error("Failed to initialize session state", { error: err.message }) - } ) + } } } diff --git a/lib/state/tool-cache.ts b/lib/state/tool-cache.ts index 3325367..1c500f0 100644 --- a/lib/state/tool-cache.ts +++ b/lib/state/tool-cache.ts @@ -6,8 +6,6 @@ const MAX_TOOL_CACHE_SIZE = 500 /** * Sync tool parameters from OpenCode's session.messages() API. - * This is the single source of truth for tool parameters, replacing - * format-specific parsing from LLM API requests. */ export async function syncToolCache( state: SessionState, @@ -24,6 +22,8 @@ export async function syncToolCache( continue } + const alreadyPruned = state.prune.toolIds.includes(part.callID) + state.toolParameters.set( part.callID, { @@ -35,7 +35,7 @@ export async function syncToolCache( } ) - if (!config.strategies.pruneTool.protectedTools.includes(part.tool)) { + if (!alreadyPruned && !config.strategies.pruneTool.protectedTools.includes(part.tool)) { state.nudgeCounter++ }