diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go
index e9fe758d3..a420da100 100644
--- a/internal/translator/codex/claude/codex_claude_response.go
+++ b/internal/translator/codex/claude/codex_claude_response.go
@@ -11,6 +11,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"sort"
 	"strings"
 
 	"github.com/tidwall/gjson"
@@ -21,6 +22,71 @@ var (
 	dataTag = []byte("data:")
 )
 
+// ConvertCodexResponseToClaudeParams is the per-stream state that
+// ConvertCodexResponseToClaude stores in *param between SSE chunks.
+//
+// HasToolCall records that a function_call item was seen and selects the final
+// stop_reason. NextContentBlockIndex is the next unused content block index.
+// BlockIndexByKey maps a block key to the index assigned to it, StartedKeys and
+// StoppedKeys record which keys have had content_block_start and
+// content_block_stop emitted, and ToolKeyByOutputIndex maps a Codex output_index
+// to the key of the tool_use block opened for it.
+type ConvertCodexResponseToClaudeParams struct {
+	HasToolCall           bool
+	NextContentBlockIndex int
+
+	BlockIndexByKey      map[string]int
+	StartedKeys          map[string]bool
+	StoppedKeys          map[string]bool
+	ToolKeyByOutputIndex map[int64]string
+}
+
+// indexForKey returns the content block index assigned to key, assigning the
+// next unused index the first time a key is seen.
+func (p *ConvertCodexResponseToClaudeParams) indexForKey(key string) int {
+	if idx, ok := p.BlockIndexByKey[key]; ok {
+		return idx
+	}
+	idx := p.NextContentBlockIndex
+	p.NextContentBlockIndex++
+	p.BlockIndexByKey[key] = idx
+	return idx
+}
+
+// toolBlockKey returns the block key recorded for outputIndex. When none is
+// recorded it falls back to "tool_use:<outputIndex>" and records that key.
+func (p *ConvertCodexResponseToClaudeParams) toolBlockKey(outputIndex int64) string {
+	if key := p.ToolKeyByOutputIndex[outputIndex]; key != "" {
+		return key
+	}
+	key := fmt.Sprintf("tool_use:%d", outputIndex)
+	p.ToolKeyByOutputIndex[outputIndex] = key
+	return key
+}
+
+// codexClaudeThinkingKey builds the block key for a reasoning summary event from
+// its output_index plus summary_index (or content_index) when one is present.
+func codexClaudeThinkingKey(root gjson.Result) string {
+	outputIndex := root.Get("output_index").Int()
+	if summaryIndex := root.Get("summary_index"); summaryIndex.Exists() {
+		return fmt.Sprintf("thinking:%d:%d", outputIndex, summaryIndex.Int())
+	}
+	if contentIndex := root.Get("content_index"); contentIndex.Exists() {
+		return fmt.Sprintf("thinking:%d:%d", outputIndex, contentIndex.Int())
+	}
+	return fmt.Sprintf("thinking:%d", outputIndex)
+}
+
+// codexClaudeTextKey builds the block key for a text event from its output_index
+// plus content_index when one is present.
+func codexClaudeTextKey(root gjson.Result) string {
+	outputIndex := root.Get("output_index").Int()
+	if contentIndex := root.Get("content_index"); contentIndex.Exists() {
+		return fmt.Sprintf("text:%d:%d", outputIndex, contentIndex.Int())
+	}
+	return fmt.Sprintf("text:%d", outputIndex)
+}
+
 // ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
 // This function implements a complex state machine that translates Codex API responses
 // into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
@@ -39,9 +105,16 @@ var (
 //   - []string: A slice of strings, each containing a Claude Code-compatible JSON response
 func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
 	if *param == nil {
-		hasToolCall := false
-		*param = &hasToolCall
+		*param = &ConvertCodexResponseToClaudeParams{
+			HasToolCall:           false,
+			NextContentBlockIndex: 0,
+			BlockIndexByKey:       make(map[string]int),
+			StartedKeys:           make(map[string]bool),
+			StoppedKeys:           make(map[string]bool),
+			ToolKeyByOutputIndex:  make(map[int64]string),
+		}
 	}
+	p := (*param).(*ConvertCodexResponseToClaudeParams)
 
 	// log.Debugf("rawJSON: %s", string(rawJSON))
 	if !bytes.HasPrefix(rawJSON, dataTag) {
@@ -62,47 +135,126 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
 		output = "event: message_start\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
 	} else if typeStr == "response.reasoning_summary_part.added" {
+		blockKey := codexClaudeThinkingKey(rootResult)
+		index := p.indexForKey(blockKey)
+
 		template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
-		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "index", index)
 
 		output = "event: content_block_start\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
+		p.StartedKeys[blockKey] = true
+		delete(p.StoppedKeys, blockKey)
 	} else if typeStr == "response.reasoning_summary_text.delta" {
+		blockKey := codexClaudeThinkingKey(rootResult)
+		index := p.indexForKey(blockKey)
+
+		// Open the block first if no content_block_start has been emitted for it yet.
+		if !p.StartedKeys[blockKey] {
+			startTemplate := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
+			startTemplate, _ = sjson.Set(startTemplate, "index", index)
+
+			output += "event: content_block_start\n"
+			output += fmt.Sprintf("data: %s\n\n", startTemplate)
+			p.StartedKeys[blockKey] = true
+			delete(p.StoppedKeys, blockKey)
+		}
+
 		template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
-		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "index", index)
 		template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
 
-		output = "event: content_block_delta\n"
+		output += "event: content_block_delta\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
 	} else if typeStr == "response.reasoning_summary_part.done" {
+		blockKey := codexClaudeThinkingKey(rootResult)
+		// Drop stops for blocks that are not open before touching the index allocator.
+		if !p.StartedKeys[blockKey] || p.StoppedKeys[blockKey] {
+			return []string{}
+		}
+		index := p.indexForKey(blockKey)
+
 		template = `{"type":"content_block_stop","index":0}`
-		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "index", index)
 
 		output = "event: content_block_stop\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
+		p.StoppedKeys[blockKey] = true
 	} else if typeStr == "response.content_part.added" {
+		blockKey := codexClaudeTextKey(rootResult)
+		index := p.indexForKey(blockKey)
+
 		template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
-		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "index", index)
 
 		output = "event: content_block_start\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
+		p.StartedKeys[blockKey] = true
+		delete(p.StoppedKeys, blockKey)
 	} else if typeStr == "response.output_text.delta" {
+		blockKey := codexClaudeTextKey(rootResult)
+		index := p.indexForKey(blockKey)
+
+		// Open the block first if no content_block_start has been emitted for it yet.
+		if !p.StartedKeys[blockKey] {
+			startTemplate := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
+			startTemplate, _ = sjson.Set(startTemplate, "index", index)
+
+			output += "event: content_block_start\n"
+			output += fmt.Sprintf("data: %s\n\n", startTemplate)
+			p.StartedKeys[blockKey] = true
+			delete(p.StoppedKeys, blockKey)
+		}
+
 		template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
-		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "index", index)
 		template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
 
-		output = "event: content_block_delta\n"
+		output += "event: content_block_delta\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
 	} else if typeStr == "response.content_part.done" {
+		blockKey := codexClaudeTextKey(rootResult)
+		// Drop stops for blocks that are not open before touching the index allocator.
+		if !p.StartedKeys[blockKey] || p.StoppedKeys[blockKey] {
+			return []string{}
+		}
+		index := p.indexForKey(blockKey)
+
 		template = `{"type":"content_block_stop","index":0}`
-		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "index", index)
 
 		output = "event: content_block_stop\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
+		p.StoppedKeys[blockKey] = true
 	} else if typeStr == "response.completed" {
+		// Close blocks that are still open, in index order (map iteration order is not stable).
+		var openBlocks []struct {
+			Index int
+			Key   string
+		}
+		for key := range p.StartedKeys {
+			if p.StartedKeys[key] && !p.StoppedKeys[key] {
+				openBlocks = append(openBlocks, struct {
+					Index int
+					Key   string
+				}{
+					Index: p.indexForKey(key),
+					Key:   key,
+				})
+			}
+		}
+		sort.Slice(openBlocks, func(i, j int) bool { return openBlocks[i].Index < openBlocks[j].Index })
+		for _, blk := range openBlocks {
+			stopTemplate := `{"type":"content_block_stop","index":0}`
+			stopTemplate, _ = sjson.Set(stopTemplate, "index", blk.Index)
+
+			output += "event: content_block_stop\n"
+			output += fmt.Sprintf("data: %s\n\n", stopTemplate)
+			p.StoppedKeys[blk.Key] = true
+		}
+
 		template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
-		p := (*param).(*bool)
-		if *p {
+		if p.HasToolCall {
 			template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
 		} else {
 			template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
@@ -110,7 +262,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
 		template, _ = sjson.Set(template, "usage.input_tokens", rootResult.Get("response.usage.input_tokens").Int())
 		template, _ = sjson.Set(template, "usage.output_tokens", rootResult.Get("response.usage.output_tokens").Int())
 
-		output = "event: message_delta\n"
+		output += "event: message_delta\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
 		output += "event: message_stop\n"
 		output += `data: {"type":"message_stop"}`
@@ -119,10 +271,16 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
 		itemResult := rootResult.Get("item")
 		itemType := itemResult.Get("type").String()
 		if itemType == "function_call" {
-			p := true
-			*param = &p
+			p.HasToolCall = true
+
+			outputIndex := rootResult.Get("output_index").Int()
+			callID := itemResult.Get("call_id").String()
+			blockKey := fmt.Sprintf("tool_use:%d:%s", outputIndex, callID)
+			p.ToolKeyByOutputIndex[outputIndex] = blockKey
+			index := p.indexForKey(blockKey)
+
 			template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
-			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+			template, _ = sjson.Set(template, "index", index)
 			template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
 			{
 				// Restore original tool name if shortened
@@ -136,9 +294,11 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
 
 			output = "event: content_block_start\n"
 			output += fmt.Sprintf("data: %s\n\n", template)
+			p.StartedKeys[blockKey] = true
+			delete(p.StoppedKeys, blockKey)
 
 			template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
-			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+			template, _ = sjson.Set(template, "index", index)
 
 			output += "event: content_block_delta\n"
 			output += fmt.Sprintf("data: %s\n\n", template)
@@ -147,21 +307,41 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
 		itemResult := rootResult.Get("item")
 		itemType := itemResult.Get("type").String()
 		if itemType == "function_call" {
+			outputIndex := rootResult.Get("output_index").Int()
+			blockKey := p.toolBlockKey(outputIndex)
+			// Drop stops for blocks that are not open before touching the index allocator.
+			if !p.StartedKeys[blockKey] || p.StoppedKeys[blockKey] {
+				return []string{}
+			}
+			index := p.indexForKey(blockKey)
+
 			template = `{"type":"content_block_stop","index":0}`
-			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+			template, _ = sjson.Set(template, "index", index)
 
 			output = "event: content_block_stop\n"
 			output += fmt.Sprintf("data: %s\n\n", template)
+			p.StoppedKeys[blockKey] = true
 		}
 	} else if typeStr == "response.function_call_arguments.delta" {
+		outputIndex := rootResult.Get("output_index").Int()
+		blockKey := p.toolBlockKey(outputIndex)
+		// Drop deltas for blocks that are not open before touching the index allocator.
+		if !p.StartedKeys[blockKey] || p.StoppedKeys[blockKey] {
+			return []string{}
+		}
+		index := p.indexForKey(blockKey)
+
 		template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
-		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "index", index)
 		template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
 
-		output += "event: content_block_delta\n"
+		output = "event: content_block_delta\n"
 		output += fmt.Sprintf("data: %s\n\n", template)
 	}
 
+	if output == "" {
+		return []string{}
+	}
 	return []string{output}
 }
 
diff --git a/internal/translator/codex/claude/codex_claude_response_test.go b/internal/translator/codex/claude/codex_claude_response_test.go
new file mode 100644
index 000000000..f0182ccf3
--- /dev/null
+++ b/internal/translator/codex/claude/codex_claude_response_test.go
@@ -0,0 +1,214 @@
+package claude
+
+import (
+	"context"
+	"encoding/json"
+	"strings"
+	"testing"
+)
+
+// parsedSSEEvent is one decoded SSE event: its event name and JSON data payload.
+type parsedSSEEvent struct {
+	Event string
+	Data  map[string]any
+}
+
+// parseSSEEvents splits the converter's output chunks into individual SSE events,
+// skipping blocks that lack an event name or a data line.
+func parseSSEEvents(t *testing.T, chunks []string) []parsedSSEEvent {
+	t.Helper()
+
+	var events []parsedSSEEvent
+	for _, chunk := range chunks {
+		for _, block := range strings.Split(chunk, "\n\n") {
+			block = strings.TrimSpace(block)
+			if block == "" {
+				continue
+			}
+
+			var eventName string
+			var dataStr string
+
+			for _, line := range strings.Split(block, "\n") {
+				if strings.HasPrefix(line, "event: ") {
+					eventName = strings.TrimSpace(strings.TrimPrefix(line, "event: "))
+				} else if strings.HasPrefix(line, "data: ") {
+					dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data: "))
+				}
+			}
+
+			if eventName == "" || dataStr == "" {
+				continue
+			}
+
+			var data map[string]any
+			if err := json.Unmarshal([]byte(dataStr), &data); err != nil {
+				t.Fatalf("failed to parse SSE JSON data %q: %v", dataStr, err)
+			}
+			events = append(events, parsedSSEEvent{Event: eventName, Data: data})
+		}
+	}
+
+	return events
+}
+
+// floatIndexToInt converts a JSON number decoded as float64 into an int.
+func floatIndexToInt(v any) (int, bool) {
+	f, ok := v.(float64)
+	if !ok {
+		return 0, false
+	}
+	return int(f), true
+}
+
+// expectedStartTypeForDelta maps a delta type to the content block type that
+// must have been started for it; ok is false for delta types it does not know.
+func expectedStartTypeForDelta(deltaType string) (string, bool) {
+	switch deltaType {
+	case "thinking_delta":
+		return "thinking", true
+	case "text_delta":
+		return "text", true
+	case "input_json_delta":
+		return "tool_use", true
+	default:
+		return "", false
+	}
+}
+
+// TestConvertCodexResponseToClaude_DoesNotReuseContentBlockIndexesAcrossTypes verifies that a
+// tool_use block and a thinking block reported under the same output_index get distinct indexes.
+func TestConvertCodexResponseToClaude_DoesNotReuseContentBlockIndexesAcrossTypes(t *testing.T) {
+	originalRequest := []byte(`{"tools":[{"name":"dummy_tool","input_schema":{"type":"object"}}]}`)
+
+	var state any
+	var outputs []string
+
+	inputs := []string{
+		`{"type":"response.created","response":{"id":"r1","model":"gpt-5.2"}}`,
+		`{"type":"response.output_item.added","output_index":0,"item":{"type":"function_call","call_id":"call_1","name":"dummy_tool"}}`,
+		`{"type":"response.function_call_arguments.delta","output_index":0,"delta":"{\"x\":1}"}`,
+		`{"type":"response.output_item.done","output_index":0,"item":{"type":"function_call"}}`,
+		`{"type":"response.reasoning_summary_part.added","output_index":0}`,
+		`{"type":"response.reasoning_summary_text.delta","output_index":0,"delta":"Thinking..."}`,
+		`{"type":"response.reasoning_summary_part.done","output_index":0}`,
+		`{"type":"response.completed","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`,
+	}
+
+	for _, in := range inputs {
+		raw := []byte("data: " + in)
+		out := ConvertCodexResponseToClaude(context.Background(), "", originalRequest, nil, raw, &state)
+		outputs = append(outputs, out...)
+	}
+
+	events := parseSSEEvents(t, outputs)
+
+	startTypeByIndex := make(map[int]string)
+	var toolUseIndex, thinkingIndex *int
+
+	for _, ev := range events {
+		typ, _ := ev.Data["type"].(string)
+		switch typ {
+		case "content_block_start":
+			idx, ok := floatIndexToInt(ev.Data["index"])
+			if !ok {
+				t.Fatalf("content_block_start missing numeric index: %#v", ev.Data["index"])
+			}
+
+			contentBlock, ok := ev.Data["content_block"].(map[string]any)
+			if !ok {
+				t.Fatalf("content_block_start missing content_block object: %#v", ev.Data["content_block"])
+			}
+			cbType, _ := contentBlock["type"].(string)
+			if cbType == "" {
+				t.Fatalf("content_block_start missing content_block.type: %#v", contentBlock)
+			}
+
+			if prev, exists := startTypeByIndex[idx]; exists && prev != cbType {
+				t.Fatalf("content_block index %d reused for different types: %q then %q", idx, prev, cbType)
+			}
+			startTypeByIndex[idx] = cbType
+
+			switch cbType {
+			case "tool_use":
+				tmp := idx
+				toolUseIndex = &tmp
+			case "thinking":
+				tmp := idx
+				thinkingIndex = &tmp
+			}
+		case "content_block_delta":
+			idx, ok := floatIndexToInt(ev.Data["index"])
+			if !ok {
+				t.Fatalf("content_block_delta missing numeric index: %#v", ev.Data["index"])
+			}
+			delta, ok := ev.Data["delta"].(map[string]any)
+			if !ok {
+				t.Fatalf("content_block_delta missing delta object: %#v", ev.Data["delta"])
+			}
+			deltaType, _ := delta["type"].(string)
+			if deltaType == "" {
+				t.Fatalf("content_block_delta missing delta.type: %#v", delta)
+			}
+			expectedStartType, ok := expectedStartTypeForDelta(deltaType)
+			if !ok {
+				continue
+			}
+			startType, exists := startTypeByIndex[idx]
+			if !exists {
+				t.Fatalf("content_block_delta for index %d (%s) without a prior content_block_start", idx, deltaType)
+			}
+			if startType != expectedStartType {
+				t.Fatalf("content_block_delta type %q mismatched for index %d: started as %q", deltaType, idx, startType)
+			}
+		case "content_block_stop":
+			idx, ok := floatIndexToInt(ev.Data["index"])
+			if !ok {
+				t.Fatalf("content_block_stop missing numeric index: %#v", ev.Data["index"])
+			}
+			if _, exists := startTypeByIndex[idx]; !exists {
+				t.Fatalf("content_block_stop for unknown index %d", idx)
+			}
+		}
+	}
+
+	if toolUseIndex == nil || thinkingIndex == nil {
+		t.Fatalf("expected both tool_use and thinking content_block_start events; got tool_use=%v thinking=%v", toolUseIndex, thinkingIndex)
+	}
+	if *toolUseIndex == *thinkingIndex {
+		t.Fatalf("tool_use and thinking blocks share the same index %d; indexes must be unique within a message", *toolUseIndex)
+	}
+}
+
+// TestConvertCodexResponseToClaude_UnopenedStopDoesNotConsumeIndex verifies that a stop
+// event for a block that was never started neither emits output nor uses up an index.
+func TestConvertCodexResponseToClaude_UnopenedStopDoesNotConsumeIndex(t *testing.T) {
+	var state any
+	var outputs []string
+
+	inputs := []string{
+		`{"type":"response.reasoning_summary_part.done","output_index":0,"summary_index":0}`,
+		`{"type":"response.content_part.added","output_index":1,"content_index":0}`,
+		`{"type":"response.output_text.delta","output_index":1,"content_index":0,"delta":"hi"}`,
+		`{"type":"response.content_part.done","output_index":1,"content_index":0}`,
+	}
+
+	for _, in := range inputs {
+		out := ConvertCodexResponseToClaude(context.Background(), "", []byte(`{}`), nil, []byte("data: "+in), &state)
+		outputs = append(outputs, out...)
+	}
+
+	events := parseSSEEvents(t, outputs)
+	if len(events) != 3 {
+		t.Fatalf("expected 3 events (start, delta, stop); got %d", len(events))
+	}
+	for _, ev := range events {
+		idx, ok := floatIndexToInt(ev.Data["index"])
+		if !ok {
+			t.Fatalf("%s missing numeric index: %#v", ev.Event, ev.Data["index"])
+		}
+		if idx != 0 {
+			t.Fatalf("%s used index %d; the first opened block must get index 0", ev.Event, idx)
+		}
+	}
+}