diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index b69fd8be0..84fd990cb 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -121,7 +121,12 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req return resp, fmt.Errorf("kiro: access token not found in auth") } if profileArn == "" { - log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + // Only warn if not using builder-id auth (which doesn't need profileArn) + if auth == nil || auth.Metadata == nil { + log.Debugf("kiro: profile ARN not found in auth (may be normal for builder-id)") + } else if authMethod, ok := auth.Metadata["auth_method"].(string); !ok || authMethod != "builder-id" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) @@ -161,10 +166,19 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req currentOrigin = "CLI" } - kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + // Determine if profileArn should be included based on auth method + // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) + effectiveProfileArn := profileArn + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + effectiveProfileArn = "" // Don't include profileArn for builder-id auth + } + } + + kiroPayload := e.buildKiroPayload(body, kiroModelID, effectiveProfileArn, currentOrigin, isAgentic, isChatOnly) // Execute with retry on 401/403 and 429 (quota exhausted) - resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, profileArn, kiroPayload, body, from, to, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) + resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, kiroPayload, body, from, to, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) return resp, err } @@ -301,9 +315,18 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } } if len(content) > 0 { - usageInfo.OutputTokens = int64(len(content) / 4) + // Use tiktoken for more accurate output token calculation + if enc, encErr := tokenizerForModel(req.Model); encErr == nil { + if tokenCount, countErr := enc.Count(content); countErr == nil { + usageInfo.OutputTokens = int64(tokenCount) + } + } + // Fallback to character count estimation if tiktoken fails if usageInfo.OutputTokens == 0 { - usageInfo.OutputTokens = 1 + usageInfo.OutputTokens = int64(len(content) / 4) + if usageInfo.OutputTokens == 0 { + usageInfo.OutputTokens = 1 + } } } usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens @@ -330,7 +353,12 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, fmt.Errorf("kiro: access token not found in auth") } if profileArn == "" { - log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + // Only warn if not using builder-id auth (which doesn't need profileArn) + if auth == nil || auth.Metadata == nil { + log.Debugf("kiro: profile ARN not found in auth (may be normal for builder-id)") + } else if authMethod, ok := auth.Metadata["auth_method"].(string); !ok || authMethod != "builder-id" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) @@ -370,10 +398,19 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut currentOrigin = "CLI" } - kiroPayload := e.buildKiroPayload(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly) + // Determine if profileArn should be included based on auth method + // profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO) + effectiveProfileArn := profileArn + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + effectiveProfileArn = "" // Don't include profileArn for builder-id auth + } + } + + kiroPayload := e.buildKiroPayload(body, kiroModelID, effectiveProfileArn, currentOrigin, isAgentic, isChatOnly) // Execute stream with retry on 401/403 and 429 (quota exhausted) - return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, profileArn, kiroPayload, body, from, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) + return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, kiroPayload, body, from, reporter, currentOrigin, kiroModelID, isAgentic, isChatOnly) } // executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. @@ -491,6 +528,12 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox go func(resp *http.Response) { defer close(out) + defer func() { + if r := recover(); r != nil { + log.Errorf("kiro: panic in stream handler: %v", r) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} + } + }() defer func() { if errClose := resp.Body.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) @@ -587,10 +630,10 @@ type kiroPayload struct { } type kiroConversationState struct { + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field ConversationID string `json:"conversationId"` - History []kiroHistoryMessage `json:"history"` CurrentMessage kiroCurrentMessage `json:"currentMessage"` - ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" + History []kiroHistoryMessage `json:"history,omitempty"` // Only include when non-empty } type kiroCurrentMessage struct { @@ -627,9 +670,9 @@ type kiroUserInputMessageContext struct { } type kiroToolResult struct { - ToolUseID string `json:"toolUseId"` Content []kiroTextContent `json:"content"` Status string `json:"status"` + ToolUseID string `json:"toolUseId"` } type kiroTextContent struct { @@ -735,7 +778,9 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, var currentUserMsg *kiroUserInputMessage var currentToolResults []kiroToolResult - messagesArray := messages.Array() + // Merge adjacent messages with the same role before processing + // This reduces API call complexity and improves compatibility + messagesArray := mergeAdjacentMessages(messages.Array()) for i, msg := range messagesArray { role := msg.Get("role").String() isLastMessage := i == len(messagesArray)-1 @@ -746,6 +791,14 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, currentUserMsg = &userMsg currentToolResults = toolResults } else { + // CRITICAL: Kiro API requires content to be non-empty for history messages too + if strings.TrimSpace(userMsg.Content) == "" { + if len(toolResults) > 0 { + userMsg.Content = "Tool results provided." + } else { + userMsg.Content = "Continue" + } + } // For history messages, embed tool results in context if len(toolResults) > 0 { userMsg.UserInputMessageContext = &kiroUserInputMessageContext{ @@ -758,9 +811,24 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, } } else if role == "assistant" { assistantMsg := e.buildAssistantMessageStruct(msg) - history = append(history, kiroHistoryMessage{ - AssistantResponseMessage: &assistantMsg, - }) + // If this is the last message and it's an assistant message, + // we need to add it to history and create a "Continue" user message + // because Kiro API requires currentMessage to be userInputMessage type + if isLastMessage { + history = append(history, kiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + // Create a "Continue" user message as currentMessage + currentUserMsg = &kiroUserInputMessage{ + Content: "Continue", + ModelID: modelID, + Origin: origin, + } + } else { + history = append(history, kiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + } } } @@ -777,7 +845,35 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, // Add the actual user message contentBuilder.WriteString(currentUserMsg.Content) - currentUserMsg.Content = contentBuilder.String() + finalContent := contentBuilder.String() + + // CRITICAL: Kiro API requires content to be non-empty, even when toolResults are present + // If content is empty or only whitespace, provide a default message + if strings.TrimSpace(finalContent) == "" { + if len(currentToolResults) > 0 { + finalContent = "Tool results provided." + } else { + finalContent = "Continue" + } + log.Debugf("kiro: content was empty, using default: %s", finalContent) + } + currentUserMsg.Content = finalContent + + // Deduplicate currentToolResults before adding to context + // Kiro API does not accept duplicate toolUseIds + if len(currentToolResults) > 0 { + seenIDs := make(map[string]bool) + uniqueToolResults := make([]kiroToolResult, 0, len(currentToolResults)) + for _, tr := range currentToolResults { + if !seenIDs[tr.ToolUseID] { + seenIDs[tr.ToolUseID] = true + uniqueToolResults = append(uniqueToolResults, tr) + } else { + log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID) + } + } + currentToolResults = uniqueToolResults + } // Build userInputMessageContext with tools and tool results if len(kiroTools) > 0 || len(currentToolResults) > 0 { @@ -805,21 +901,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, }} } + // Build payload with correct field order (matches struct definition) + // Note: history is omitempty, so nil/empty slice won't be serialized payload := kiroPayload{ ConversationState: kiroConversationState{ + ChatTriggerType: "MANUAL", // Required by Kiro API - must be first ConversationID: uuid.New().String(), - History: history, CurrentMessage: currentMessage, - ChatTriggerType: "MANUAL", // Required by Kiro API + History: history, // Will be omitted if empty due to omitempty tag }, ProfileArn: profileArn, } - // Ensure history is not nil (empty array) - if payload.ConversationState.History == nil { - payload.ConversationState.History = []kiroHistoryMessage{} - } - result, err := json.Marshal(payload) if err != nil { log.Debugf("kiro: failed to marshal payload: %v", err) @@ -830,11 +923,15 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, // buildUserMessageStruct builds a user message and extracts tool results // origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. +// IMPORTANT: Kiro API does not accept duplicate toolUseIds, so we deduplicate here. func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin string) (kiroUserInputMessage, []kiroToolResult) { content := msg.Get("content") var contentBuilder strings.Builder var toolResults []kiroToolResult var images []kiroImage + + // Track seen toolUseIds to deduplicate - Kiro API rejects duplicate toolUseIds + seenToolUseIDs := make(map[string]bool) if content.IsArray() { for _, part := range content.Array() { @@ -864,6 +961,14 @@ func (e *KiroExecutor) buildUserMessageStruct(msg gjson.Result, modelID, origin case "tool_result": // Extract tool result for API toolUseID := part.Get("tool_use_id").String() + + // Skip duplicate toolUseIds - Kiro API does not accept duplicates + if seenToolUseIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID) + continue + } + seenToolUseIDs[toolUseID] = true + isError := part.Get("is_error").Bool() resultContent := part.Get("content") @@ -1001,6 +1106,12 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err) } + // Validate headersLen to prevent slice out of bounds + if headersLen+4 > uint32(len(remaining)) { + log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + continue + } + // Extract event type from headers eventType := e.extractEventType(remaining[:headersLen+4]) @@ -1018,6 +1129,37 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, continue } + // DIAGNOSTIC: Log all received event types for debugging + log.Debugf("kiro: parseEventStream received event type: %s", eventType) + if log.IsLevelEnabled(log.TraceLevel) { + log.Tracef("kiro: parseEventStream event payload: %s", string(payload)) + } + + // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) + // These can appear as top-level fields or nested within the event + if errType, hasErrType := event["_type"].(string); hasErrType { + // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } + log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg) + return "", nil, usageInfo, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) + } + if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { + // Generic error event + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) + return "", nil, usageInfo, fmt.Errorf("kiro API error: %s", errMsg) + } + // Handle different event types switch eventType { case "assistantResponseEvent": @@ -1231,8 +1373,9 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs // streamToChannel converts AWS Event Stream to channel-based streaming. // Supports tool calling - emits tool_use content blocks when tools are used. // Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. +// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) { - reader := bufio.NewReader(body) + reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers var totalUsage usage.Detail var hasToolUses bool // Track if any tool uses were emitted @@ -1240,6 +1383,15 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out processedIDs := make(map[string]bool) var currentToolUse *toolUseState + // Duplicate content detection - tracks last content event to filter duplicates + // Based on AIClient-2-API implementation for Kiro + var lastContentEvent string + + // Streaming token calculation - accumulate content for real-time token counting + // Based on AIClient-2-API implementation + var accumulatedContent strings.Builder + accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations + // Translator param for maintaining tool call state across streaming events // IMPORTANT: This must persist across all TranslateStream calls var translatorParam any @@ -1279,6 +1431,51 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out prelude := make([]byte, 8) _, err := io.ReadFull(reader, prelude) if err == io.EOF { + // Flush any incomplete tool use before ending stream + if currentToolUse != nil && !processedIDs[currentToolUse.toolUseID] { + log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID) + fullInput := currentToolUse.inputBuffer.String() + repairedJSON := repairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err) + finalInput = make(map[string]interface{}) + } + + processedIDs[currentToolUse.toolUseID] = true + contentBlockIndex++ + + // Send tool_use content block + blockStart := e.buildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.toolUseID, currentToolUse.name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send tool input as delta + inputBytes, _ := json.Marshal(finalInput) + inputDelta := e.buildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Close block + blockStop := e.buildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + hasToolUses = true + currentToolUse = nil + } break } if err != nil { @@ -1304,6 +1501,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out return } + // Validate headersLen to prevent slice out of bounds + if headersLen+4 > uint32(len(remaining)) { + log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + continue + } + eventType := e.extractEventType(remaining[:headersLen+4]) payloadStart := 4 + headersLen @@ -1317,9 +1520,43 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out var event map[string]interface{} if err := json.Unmarshal(payload, &event); err != nil { + log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) continue } + // DIAGNOSTIC: Log all received event types for debugging + log.Debugf("kiro: streamToChannel received event type: %s", eventType) + if log.IsLevelEnabled(log.TraceLevel) { + log.Tracef("kiro: streamToChannel event payload: %s", string(payload)) + } + + // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) + // These can appear as top-level fields or nested within the event + if errType, hasErrType := event["_type"].(string); hasErrType { + // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } + log.Errorf("kiro: received AWS error in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s - %s", errType, errMsg)} + return + } + if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { + // Generic error event + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s", errMsg)} + return + } + // Send message_start on first event if !messageStartSent { msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens) @@ -1364,9 +1601,19 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - // Handle text content + // Handle text content with duplicate detection if contentDelta != "" { + // Check for duplicate content - skip if identical to last content event + // Based on AIClient-2-API implementation for Kiro + if contentDelta == lastContentEvent { + log.Debugf("kiro: skipping duplicate content event (len: %d)", len(contentDelta)) + continue + } + lastContentEvent = contentDelta + outputLen += len(contentDelta) + // Accumulate content for streaming token calculation + accumulatedContent.WriteString(contentDelta) // Start text content block if needed if !isTextBlockOpen { contentBlockIndex++ @@ -1538,8 +1785,32 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out } } - // Fallback for output tokens if not received from upstream - if totalUsage.OutputTokens == 0 && outputLen > 0 { + // Streaming token calculation - calculate output tokens from accumulated content + // This provides more accurate token counting than simple character division + if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { + // Try to use tiktoken for accurate counting + if enc, err := tokenizerForModel(model); err == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + totalUsage.OutputTokens = int64(tokenCount) + log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) + } else { + // Fallback on count error: estimate from character count + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel tiktoken count failed, estimated from chars: %d", totalUsage.OutputTokens) + } + } else { + // Fallback: estimate from character count (roughly 4 chars per token) + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel estimated output tokens from chars: %d (content len: %d)", totalUsage.OutputTokens, accumulatedContent.Len()) + } + } else if totalUsage.OutputTokens == 0 && outputLen > 0 { + // Legacy fallback using outputLen totalUsage.OutputTokens = int64(outputLen / 4) if totalUsage.OutputTokens == 0 { totalUsage.OutputTokens = 1 @@ -1553,9 +1824,18 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out stopReason = "tool_use" } - // Send message_delta and message_stop - msgStop := e.buildClaudeMessageStopEvent(stopReason, totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + // Send message_delta event + msgDelta := e.buildClaudeMessageDeltaEvent(stopReason, totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send message_stop event separately + msgStop := e.buildClaudeMessageStopOnlyEvent() + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} @@ -1646,8 +1926,8 @@ func (e *KiroExecutor) buildClaudeContentBlockStopEvent(index int) []byte { return []byte("event: content_block_stop\ndata: " + string(result)) } -func (e *KiroExecutor) buildClaudeMessageStopEvent(stopReason string, usageInfo usage.Detail) []byte { - // First message_delta +// buildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage. +func (e *KiroExecutor) buildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte { deltaEvent := map[string]interface{}{ "type": "message_delta", "delta": map[string]interface{}{ @@ -1660,14 +1940,16 @@ func (e *KiroExecutor) buildClaudeMessageStopEvent(stopReason string, usageInfo }, } deltaResult, _ := json.Marshal(deltaEvent) + return []byte("event: message_delta\ndata: " + string(deltaResult)) +} - // Then message_stop +// buildClaudeMessageStopOnlyEvent creates only the message_stop event. +func (e *KiroExecutor) buildClaudeMessageStopOnlyEvent() []byte { stopEvent := map[string]interface{}{ "type": "message_stop", } stopResult, _ := json.Marshal(stopEvent) - - return []byte("event: message_delta\ndata: " + string(deltaResult) + "\n\nevent: message_stop\ndata: " + string(stopResult)) + return []byte("event: message_stop\ndata: " + string(stopResult)) } // buildClaudeFinalEvent constructs the final Claude-style event. @@ -1873,6 +2155,12 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c return fmt.Errorf("failed to read message: %w", err) } + // Validate headersLen to prevent slice out of bounds + if headersLen+4 > uint32(len(remaining)) { + log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining)) + continue + } + eventType := e.extractEventType(remaining[:headersLen+4]) payloadStart := 4 + headersLen @@ -1886,6 +2174,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c var event map[string]interface{} if err := json.Unmarshal(payload, &event); err != nil { + log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) continue } @@ -1983,9 +2272,19 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c } totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens - // Always use end_turn (no tool_use support) - msgStop := e.buildClaudeMessageStopEvent("end_turn", totalUsage) - sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + // Send message_delta event + msgDelta := e.buildClaudeMessageDeltaEvent("end_turn", totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + c.Writer.Write([]byte(chunk + "\n\n")) + } + } + c.Writer.Flush() + + // Send message_stop event separately + msgStop := e.buildClaudeMessageStopOnlyEvent() + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) for _, chunk := range sseData { if chunk != "" { c.Writer.Write([]byte(chunk + "\n\n")) @@ -2057,6 +2356,128 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool { return isExpired } +// ============================================================================ +// Message Merging Support - Merge adjacent messages with the same role +// Based on AIClient-2-API implementation +// ============================================================================ + +// mergeAdjacentMessages merges adjacent messages with the same role. +// This reduces API call complexity and improves compatibility. +// Based on AIClient-2-API implementation. +func mergeAdjacentMessages(messages []gjson.Result) []gjson.Result { + if len(messages) <= 1 { + return messages + } + + var merged []gjson.Result + for _, msg := range messages { + if len(merged) == 0 { + merged = append(merged, msg) + continue + } + + lastMsg := merged[len(merged)-1] + currentRole := msg.Get("role").String() + lastRole := lastMsg.Get("role").String() + + if currentRole == lastRole { + // Merge content from current message into last message + mergedContent := mergeMessageContent(lastMsg, msg) + // Create a new merged message JSON + mergedMsg := createMergedMessage(lastRole, mergedContent) + merged[len(merged)-1] = gjson.Parse(mergedMsg) + } else { + merged = append(merged, msg) + } + } + + return merged +} + +// mergeMessageContent merges the content of two messages with the same role. +// Handles both string content and array content (with text, tool_use, tool_result blocks). +func mergeMessageContent(msg1, msg2 gjson.Result) string { + content1 := msg1.Get("content") + content2 := msg2.Get("content") + + // Extract content blocks from both messages + var blocks1, blocks2 []map[string]interface{} + + if content1.IsArray() { + for _, block := range content1.Array() { + blocks1 = append(blocks1, blockToMap(block)) + } + } else if content1.Type == gjson.String { + blocks1 = append(blocks1, map[string]interface{}{ + "type": "text", + "text": content1.String(), + }) + } + + if content2.IsArray() { + for _, block := range content2.Array() { + blocks2 = append(blocks2, blockToMap(block)) + } + } else if content2.Type == gjson.String { + blocks2 = append(blocks2, map[string]interface{}{ + "type": "text", + "text": content2.String(), + }) + } + + // Merge text blocks if both end/start with text + if len(blocks1) > 0 && len(blocks2) > 0 { + if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" { + // Merge the last text block of msg1 with the first text block of msg2 + text1 := blocks1[len(blocks1)-1]["text"].(string) + text2 := blocks2[0]["text"].(string) + blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2 + blocks2 = blocks2[1:] // Remove the merged block from blocks2 + } + } + + // Combine all blocks + allBlocks := append(blocks1, blocks2...) + + // Convert to JSON + result, _ := json.Marshal(allBlocks) + return string(result) +} + +// blockToMap converts a gjson.Result block to a map[string]interface{} +func blockToMap(block gjson.Result) map[string]interface{} { + result := make(map[string]interface{}) + block.ForEach(func(key, value gjson.Result) bool { + if value.IsObject() { + result[key.String()] = blockToMap(value) + } else if value.IsArray() { + var arr []interface{} + for _, item := range value.Array() { + if item.IsObject() { + arr = append(arr, blockToMap(item)) + } else { + arr = append(arr, item.Value()) + } + } + result[key.String()] = arr + } else { + result[key.String()] = value.Value() + } + return true + }) + return result +} + +// createMergedMessage creates a JSON string for a merged message +func createMergedMessage(role string, content string) string { + msg := map[string]interface{}{ + "role": role, + "content": json.RawMessage(content), + } + result, _ := json.Marshal(msg) + return string(result) +} + // ============================================================================ // Tool Calling Support - Embedded tool call parsing and input buffering // Based on amq2api and AIClient-2-API implementations @@ -2079,8 +2500,6 @@ var ( whitespaceCollapsePattern = regexp.MustCompile(`\s+`) // trailingCommaPattern matches trailing commas before closing braces/brackets trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) - // unquotedKeyPattern matches unquoted JSON keys that need quoting - unquotedKeyPattern = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`) ) // parseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. @@ -2246,14 +2665,208 @@ func findMatchingBracket(text string, startPos int) int { } // repairJSON attempts to fix common JSON issues that may occur in tool call arguments. -// Based on AIClient-2-API's JSON repair implementation. +// Based on AIClient-2-API's JSON repair implementation with a more conservative strategy. +// +// Conservative repair strategy: +// 1. First try to parse JSON directly - if valid, return as-is +// 2. Only attempt repair if parsing fails +// 3. After repair, validate the result - if still invalid, return original +// +// Handles incomplete JSON by balancing brackets and removing trailing incomplete content. // Uses pre-compiled regex patterns for performance. -func repairJSON(raw string) string { +func repairJSON(jsonString string) string { + // Handle empty or invalid input + if jsonString == "" { + return "{}" + } + + str := strings.TrimSpace(jsonString) + if str == "" { + return "{}" + } + + // CONSERVATIVE STRATEGY: First try to parse directly + // If the JSON is already valid, return it unchanged + var testParse interface{} + if err := json.Unmarshal([]byte(str), &testParse); err == nil { + log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged") + return str + } + + log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair") + originalStr := str // Keep original for fallback + + // First, escape unescaped newlines/tabs within JSON string values + str = escapeNewlinesInStrings(str) // Remove trailing commas before closing braces/brackets - repaired := trailingCommaPattern.ReplaceAllString(raw, "$1") - // Fix unquoted keys (basic attempt - handles simple cases) - repaired = unquotedKeyPattern.ReplaceAllString(repaired, `$1"$2":`) - return repaired + str = trailingCommaPattern.ReplaceAllString(str, "$1") + + // Calculate bracket balance to detect incomplete JSON + braceCount := 0 // {} balance + bracketCount := 0 // [] balance + inString := false + escape := false + lastValidIndex := -1 + + for i := 0; i < len(str); i++ { + char := str[i] + + // Handle escape sequences + if escape { + escape = false + continue + } + + if char == '\\' { + escape = true + continue + } + + // Handle string boundaries + if char == '"' { + inString = !inString + continue + } + + // Skip characters inside strings (they don't affect bracket balance) + if inString { + continue + } + + // Track bracket balance + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + + // Record last valid position (where brackets are balanced or positive) + if braceCount >= 0 && bracketCount >= 0 { + lastValidIndex = i + } + } + + // If brackets are unbalanced, try to repair + if braceCount > 0 || bracketCount > 0 { + // Truncate to last valid position if we have incomplete content + if lastValidIndex > 0 && lastValidIndex < len(str)-1 { + // Check if truncation would help (only truncate if there's trailing garbage) + truncated := str[:lastValidIndex+1] + // Recount brackets after truncation + braceCount = 0 + bracketCount = 0 + inString = false + escape = false + for i := 0; i < len(truncated); i++ { + char := truncated[i] + if escape { + escape = false + continue + } + if char == '\\' { + escape = true + continue + } + if char == '"' { + inString = !inString + continue + } + if inString { + continue + } + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + } + str = truncated + } + + // Add missing closing brackets + for braceCount > 0 { + str += "}" + braceCount-- + } + for bracketCount > 0 { + str += "]" + bracketCount-- + } + } + + // CONSERVATIVE STRATEGY: Validate repaired JSON + // If repair didn't produce valid JSON, return original string + if err := json.Unmarshal([]byte(str), &testParse); err != nil { + log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original") + return originalStr + } + + log.Debugf("kiro: repairJSON - successfully repaired JSON") + return str +} + +// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters +// that appear inside JSON string values. This handles cases where streaming fragments +// contain unescaped control characters within string content. +func escapeNewlinesInStrings(raw string) string { + var result strings.Builder + result.Grow(len(raw) + 100) // Pre-allocate with some extra space + + inString := false + escaped := false + + for i := 0; i < len(raw); i++ { + c := raw[i] + + if escaped { + // Previous character was backslash, this is an escape sequence + result.WriteByte(c) + escaped = false + continue + } + + if c == '\\' && inString { + // Start of escape sequence + result.WriteByte(c) + escaped = true + continue + } + + if c == '"' { + // Toggle string state + inString = !inString + result.WriteByte(c) + continue + } + + if inString { + // Inside a string, escape control characters + switch c { + case '\n': + result.WriteString("\\n") + case '\r': + result.WriteString("\\r") + case '\t': + result.WriteString("\\t") + default: + result.WriteByte(c) + } + } else { + result.WriteByte(c) + } + } + + return result.String() } // processToolUseEvent handles a toolUseEvent from the Kiro stream. @@ -2330,6 +2943,8 @@ func (e *KiroExecutor) processToolUseEvent(event map[string]interface{}, current // Accumulate input fragments if currentToolUse != nil && inputFragment != "" { + // Accumulate fragments directly - they form valid JSON when combined + // The fragments are already decoded from JSON, so we just concatenate them currentToolUse.inputBuffer.WriteString(inputFragment) log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.inputBuffer.Len()) } diff --git a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go index df75cc070..d56c94acf 100644 --- a/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go +++ b/internal/translator/kiro/openai/chat-completions/kiro_openai_response.go @@ -171,7 +171,7 @@ func convertClaudeEventToOpenAI(jsonStr string, model string) []string { return results case "message_delta": - // Final message delta with stop_reason + // Final message delta with stop_reason and usage stopReason := root.Get("delta.stop_reason").String() if stopReason != "" { finishReason := "stop" @@ -196,6 +196,19 @@ func convertClaudeEventToOpenAI(jsonStr string, model string) []string { }, }, } + + // Extract and include usage information from message_delta event + usage := root.Get("usage") + if usage.Exists() { + inputTokens := usage.Get("input_tokens").Int() + outputTokens := usage.Get("output_tokens").Int() + response["usage"] = map[string]interface{}{ + "prompt_tokens": inputTokens, + "completion_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + } + } + result, _ := json.Marshal(response) results = append(results, string(result)) }