diff --git a/go.mod b/go.mod index ab0265f..032c766 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/tuannvm/slack-mcp-client go 1.24.4 require ( + github.com/aws/aws-sdk-go-v2/config v1.29.4 + github.com/aws/aws-sdk-go-v2/service/s3vectors v1.4.10 github.com/joho/godotenv v1.5.1 github.com/mark3labs/mcp-go v0.42.0 github.com/openai/openai-go v1.8.2 @@ -24,6 +26,18 @@ require ( github.com/Masterminds/sprig/v3 v3.2.3 // indirect github.com/PuerkitoBio/goquery v1.8.1 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect + github.com/aws/aws-sdk-go-v2 v1.39.4 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.57 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.12 // indirect + github.com/aws/smithy-go v1.23.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/go.sum b/go.sum index ec3b742..5c934d4 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,34 @@ github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZ github.com/andybalholm/cascadia v1.3.1/go.mod h1:R4bJ1UQfqADjvDa4P6HZHLh/3OxWWEqc0Sk8XGwHqvA= github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= +github.com/aws/aws-sdk-go-v2 v1.39.4 h1:qTsQKcdQPHnfGYBBs+Btl8QwxJeoWcOcPcixK90mRhg= +github.com/aws/aws-sdk-go-v2 v1.39.4/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/config v1.29.4 h1:ObNqKsDYFGr2WxnoXKOhCvTlf3HhwtoGgc+KmZ4H5yg= +github.com/aws/aws-sdk-go-v2/config v1.29.4/go.mod h1:j2/AF7j/qxVmsNIChw1tWfsVKOayJoGRDjg1Tgq7NPk= +github.com/aws/aws-sdk-go-v2/credentials v1.17.57 h1:kFQDsbdBAR3GZsB8xA+51ptEnq9TIj3tS4MuP5b+TcQ= +github.com/aws/aws-sdk-go-v2/credentials v1.17.57/go.mod h1:2kerxPUUbTagAr/kkaHiqvj/bcYHzi2qiJS/ZinllU0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 h1:7AANQZkF3ihM8fbdftpjhken0TP9sBzFbV/Ze/Y4HXA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11/go.mod h1:NTF4QCGkm6fzVwncpkFQqoquQyOolcyXfbpC98urj+c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 h1:ShdtWUZT37LCAA4Mw2kJAJtzaszfSHFb5n25sdcv4YE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11/go.mod h1:7bUb2sSr2MZ3M/N+VyETLTQtInemHXb/Fl3s8CLzm0Y= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= 
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/s3vectors v1.4.10 h1:hgJrhznAL6SjFZAqNIexiE9L7Zjc5PMGmwPWNtTE3zc= +github.com/aws/aws-sdk-go-v2/service/s3vectors v1.4.10/go.mod h1:gJNoydxeaa5Av62mqcKTcA/9oFJnnZRseWfDmPKfGv8= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.12 h1:fqg6c1KVrc3SYWma/egWue5rKI4G2+M4wMQN2JosNAA= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.12/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= diff --git a/helm-chart/slack-mcp-client/templates/deployment.yaml b/helm-chart/slack-mcp-client/templates/deployment.yaml index 2482147..de52b9b 100644 --- a/helm-chart/slack-mcp-client/templates/deployment.yaml +++ b/helm-chart/slack-mcp-client/templates/deployment.yaml @@ -24,9 +24,10 @@ spec: {{- end }} securityContext: {{- toYaml .Values.podSecurityContext | nindent 8 }} - {{- if and .Values.serviceAccount.create .Values.serviceAccount.clusterRoleName }} - serviceAccountName: {{ include "slack-mcp-client.fullname" . }} + {{- if .Values.serviceAccount.create }} + serviceAccountName: {{ .Values.serviceAccount.name | default (include "slack-mcp-client.fullname" .) }} {{- end }} + {{- if .Values.initContainers }} initContainers: {{- range .Values.initContainers }} diff --git a/helm-chart/slack-mcp-client/templates/service-account.yaml b/helm-chart/slack-mcp-client/templates/service-account.yaml index 87d0334..889ad70 100644 --- a/helm-chart/slack-mcp-client/templates/service-account.yaml +++ b/helm-chart/slack-mcp-client/templates/service-account.yaml @@ -1,8 +1,12 @@ -{{- if and .Values.serviceAccount.create .Values.serviceAccount.clusterRoleName }} +{{- if .Values.serviceAccount.create }} apiVersion: v1 kind: ServiceAccount metadata: - name: {{ include "slack-mcp-client.fullname" . }} + name: {{ .Values.serviceAccount.name | default (include "slack-mcp-client.fullname" .) }} labels: {{- include "slack-mcp-client.labels" . | nindent 4 }} + {{- with .Values.serviceAccount.annotations }} + annotations: + {{- toYaml . 
| nindent 4 }} + {{- end }} {{- end }} diff --git a/internal/config/config.go b/internal/config/config.go index 3fdbcf1..2e7ba91 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,17 +22,19 @@ const ( // Config represents the main application configuration type Config struct { - Version string `json:"version"` - Slack SlackConfig `json:"slack"` - LLM LLMConfig `json:"llm"` - MCPServers map[string]MCPServerConfig `json:"mcpServers"` - RAG RAGConfig `json:"rag,omitempty"` - Monitoring MonitoringConfig `json:"monitoring,omitempty"` - Timeouts TimeoutConfig `json:"timeouts,omitempty"` - Retry RetryConfig `json:"retry,omitempty"` - Reload ReloadConfig `json:"reload,omitempty"` - Observability ObservabilityConfig `json:"observability,omitempty"` - UseStdIOClient bool `json:"useStdIOClient,omitempty"` // Use terminal client instead of a real slack bot, for local development + Version string `json:"version"` + Slack SlackConfig `json:"slack"` + LLM LLMConfig `json:"llm"` + MCPServers map[string]MCPServerConfig `json:"mcpServers"` + QueryEnhancementProvider string `json:"queryEnhancementProvider,omitempty"` // Optional: LLM provider for query enhancement (applies to all queries) + QueryEnhancementPromptFile string `json:"queryEnhancementPromptFile,omitempty"` // Optional: Path to custom query enhancement prompt file + RAG RAGConfig `json:"rag,omitempty"` + Monitoring MonitoringConfig `json:"monitoring,omitempty"` + Timeouts TimeoutConfig `json:"timeouts,omitempty"` + Retry RetryConfig `json:"retry,omitempty"` + Reload ReloadConfig `json:"reload,omitempty"` + Observability ObservabilityConfig `json:"observability,omitempty"` + UseStdIOClient bool `json:"useStdIOClient,omitempty"` // Use terminal client instead of a real slack bot, for local development } // SlackConfig contains Slack-specific configuration @@ -109,26 +111,37 @@ type MCPToolsConfig struct { // RAGConfig contains RAG system configuration type RAGConfig struct { - Enabled bool `json:"enabled,omitempty"` - Provider string `json:"provider,omitempty"` - ChunkSize int `json:"chunkSize,omitempty"` - Providers map[string]RAGProviderConfig `json:"providers,omitempty"` + Enabled bool `json:"enabled,omitempty"` + Provider string `json:"provider,omitempty"` + ChunkSize int `json:"chunkSize,omitempty"` + EmbeddingProvider string `json:"embeddingProvider,omitempty"` // Optional: Embedding provider (voyage, openai, cohere, etc.) 
+ Providers map[string]RAGProviderConfig `json:"providers,omitempty"` + EmbeddingProviders map[string]RAGEmbeddingProviderConfig `json:"embeddingProviders,omitempty"` // Embedding provider configs } // RAGProviderConfig contains RAG provider-specific settings // TODO: Refactor this to use a common interface for all RAG providers, can use environment variables to configure the different providers type RAGProviderConfig struct { DatabasePath string `json:"databasePath,omitempty"` // Simple provider: path to JSON database - IndexName string `json:"indexName,omitempty"` // OpenAI provider: vector store name + IndexName string `json:"indexName,omitempty"` // OpenAI/S3 provider: vector store/index name VectorStoreID string `json:"vectorStoreId,omitempty"` // OpenAI provider: existing vector store ID Dimensions int `json:"dimensions,omitempty"` // OpenAI provider: embedding dimensions SimilarityMetric string `json:"similarityMetric,omitempty"` // OpenAI provider: similarity metric - MaxResults int `json:"maxResults,omitempty"` // OpenAI provider: maximum search results - ScoreThreshold float64 `json:"scoreThreshold,omitempty"` // OpenAI provider: score threshold + MaxResults int `json:"maxResults,omitempty"` // OpenAI/S3 provider: maximum search results + ScoreThreshold float64 `json:"scoreThreshold,omitempty"` // OpenAI/S3 provider: score threshold RewriteQuery bool `json:"rewriteQuery,omitempty"` // OpenAI provider: rewrite query VectorStoreNameRegex string `json:"vectorStoreNameRegex,omitempty"` // OpenAI provider: vector store name regex VectorStoreMetadataKey string `json:"vectorStoreMetadataKey,omitempty"` // OpenAI provider: vector store metadata key VectorStoreMetadataValue string `json:"vectorStoreMetadataValue,omitempty"` // OpenAI provider: vector store metadata value + BucketName string `json:"bucketName,omitempty"` // S3 provider: S3 bucket name + Region string `json:"region,omitempty"` // S3 provider: AWS region + DateFilterField string `json:"dateFilterField,omitempty"` // Date filter metadata field name + DateRangeWindowDays int `json:"dateRangeWindowDays,omitempty"` // Days to expand date range backward (default: 7) +} + +// RAGEmbeddingProviderConfig contains embedding provider-specific settings +type RAGEmbeddingProviderConfig struct { + APIKey string `json:"apiKey,omitempty"` // API key for the embedding provider } // MonitoringConfig contains monitoring and observability settings diff --git a/internal/config/validation.go b/internal/config/validation.go index 1935d80..7029941 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -194,6 +194,12 @@ func (c *Config) SubstituteEnvironmentVariables() { c.Observability.ServiceName = substituteEnvVars(c.Observability.ServiceName) c.Observability.ServiceVersion = substituteEnvVars(c.Observability.ServiceVersion) + // Substitute in RAG Embedding Providers configuration + for name, provider := range c.RAG.EmbeddingProviders { + provider.APIKey = substituteEnvVars(provider.APIKey) + c.RAG.EmbeddingProviders[name] = provider + } + } // substituteEnvVars replaces ${VAR_NAME} patterns with environment variable values diff --git a/internal/handlers/llm_mcp_bridge.go b/internal/handlers/llm_mcp_bridge.go index 1008c37..0ad9f87 100644 --- a/internal/handlers/llm_mcp_bridge.go +++ b/internal/handlers/llm_mcp_bridge.go @@ -224,13 +224,14 @@ func NewLLMMCPBridgeFromClientsWithLogLevel(mcpClients interface{}, stdLogger *l return NewLLMMCPBridgeWithLogLevel(interfaceClients, stdLogger, discoveredTools, logLevel, 
llmRegistry, cfg) } -// ProcessLLMResponse processes an LLM response, expecting a specific JSON tool call format. -// It no longer uses natural language detection. -func (b *LLMMCPBridge) ProcessLLMResponse(ctx context.Context, llmResponse *llms.ContentChoice, _ string, extraArgs map[string]interface{}) (string, error) { +// ExtractToolCall extracts tool call information from LLM response +// Returns nil if no tool call is detected +func (b *LLMMCPBridge) ExtractToolCall(llmResponse *llms.ContentChoice) (*ToolCall, error) { var toolCall *ToolCall var err error + + // Check for native tool calls first funcCall := llmResponse.FuncCall - // Check for a tool call in JSON format if len(llmResponse.ToolCalls) > 0 { funcCall = llmResponse.ToolCalls[0].FunctionCall } @@ -238,40 +239,47 @@ func (b *LLMMCPBridge) ProcessLLMResponse(ctx context.Context, llmResponse *llms if funcCall != nil { toolCall, err = b.getToolCall(funcCall) if err != nil { - return "", err + return nil, err } } else { + // Fallback: try to detect JSON tool call in Content toolCall = b.detectSpecificJSONToolCall(llmResponse.Content) } - if toolCall != nil { - // Execute the tool call - result, err := b.executeToolCall(ctx, toolCall, extraArgs) - if err != nil { - // Check if it's already a domain error - var errorMessage string - if customErrors.IsDomainError(err) { - // Extract structured information from the domain error - code, _ := customErrors.GetErrorCode(err) - b.logger.ErrorKV("Failed to execute tool call", - "error", err.Error(), - "error_code", code, - "tool", toolCall.Tool) - errorMessage = fmt.Sprintf("Error executing tool call: %v (code: %s)", err, code) - } else { - b.logger.ErrorKV("Failed to execute tool call", - "error", err.Error(), - "tool", toolCall.Tool) - errorMessage = fmt.Sprintf("Error executing tool call: %v", err) - } + return toolCall, nil +} - return errorMessage, nil +// ExecuteToolCall executes a tool call and returns the result +func (b *LLMMCPBridge) ExecuteToolCall(ctx context.Context, toolCall *ToolCall, extraArgs map[string]interface{}) (string, error) { + if toolCall == nil { + return "", fmt.Errorf("toolCall cannot be nil") + } + + // Execute the tool call + result, err := b.executeToolCall(ctx, toolCall, extraArgs) + if err != nil { + // Check if it's already a domain error + var errorMessage string + if customErrors.IsDomainError(err) { + // Extract structured information from the domain error + code, _ := customErrors.GetErrorCode(err) + b.logger.ErrorKV("Failed to execute tool call", + "error", err.Error(), + "error_code", code, + "tool", toolCall.Tool) + errorMessage = err.Error() + } else { + // Wrap as domain error + domainErr := customErrors.WrapMCPError(err, "tool_execution_failed", + fmt.Sprintf("Failed to execute tool '%s'", toolCall.Tool)) + b.logger.ErrorKV("Failed to execute tool call", "error", domainErr.Error(), "tool", toolCall.Tool) + errorMessage = domainErr.Error() } - return result, nil + return "", fmt.Errorf("%s", errorMessage) } - // Just return the LLM response as-is if no tool call was detected - return llmResponse.Content, nil + b.logger.DebugKV("Tool call executed successfully", "tool", toolCall.Tool, "result_length", len(result)) + return result, nil } // ToolCall represents the expected JSON structure for a tool call from the LLM diff --git a/internal/observability/langfuse.go b/internal/observability/langfuse.go index c28fa0d..0d5732c 100644 --- a/internal/observability/langfuse.go +++ b/internal/observability/langfuse.go @@ -197,17 +197,23 @@ func (p 
*LangfuseProvider) SetOutput(span OtelTrace.Span, output string) { } func (p *LangfuseProvider) SetTokenUsage(span OtelTrace.Span, promptTokens, completionTokens, reasoningTokens, totalTokens int) { - // Langfuse usage format + // Langfuse usage format - uses "input", "output", "total" field names + // Map our standard token names to Langfuse's expected format usageDetails := map[string]int{ - "prompt_tokens": promptTokens, - "completion_tokens": completionTokens, - "total_tokens": totalTokens, - "reasoning_tokens": reasoningTokens, + "input": promptTokens, + "output": completionTokens, + "total": totalTokens, + } + + // Add reasoning tokens as a separate metadata field since Langfuse doesn't have a standard field for it + if reasoningTokens > 0 { + usageDetails["reasoning_tokens"] = reasoningTokens } if usageJSON, err := json.Marshal(usageDetails); err == nil { span.SetAttributes( attribute.String("langfuse.observation.usage_details", string(usageJSON)), + // Also set OpenTelemetry standard fields for compatibility attribute.Int("llm.token_count.prompt_tokens", promptTokens), attribute.Int("llm.token_count.completion_tokens", completionTokens), attribute.Int("llm.token_count.total_tokens", totalTokens), diff --git a/internal/rag/client.go b/internal/rag/client.go index 8be0b62..2e43635 100644 --- a/internal/rag/client.go +++ b/internal/rag/client.go @@ -4,13 +4,20 @@ package rag import ( "context" "fmt" + "sort" "strings" + "time" + + "github.com/tuannvm/slack-mcp-client/internal/observability" ) // Client wraps vector providers to implement the MCP tool interface // This allows the LLM-MCP bridge to treat RAG as a regular MCP tool type Client struct { - provider VectorProvider + provider VectorProvider + embeddingProvider EmbeddingProvider // Interface for embedding providers (Voyage, OpenAI, etc.) 
+ config map[string]interface{} // Raw config for accessing provider-specific settings + tracingHandler interface{} // Tracing handler for observability (optional) } // NewClient creates a new RAG client with simple provider (legacy compatibility) @@ -27,11 +34,13 @@ func NewClient(ragDatabase string) *Client { _ = simpleProvider.Initialize(context.Background()) return &Client{ provider: simpleProvider, + config: config, } } return &Client{ provider: provider, + config: config, } } @@ -50,9 +59,21 @@ func NewClientWithProvider(providerType string, config map[string]interface{}) ( return &Client{ provider: provider, + config: config, }, nil } +// SetEmbeddingProvider sets the embedding provider for enhanced RAG search +// Query enhancement is now done before RAG search in the Slack client layer +func (c *Client) SetEmbeddingProvider(embeddingProvider EmbeddingProvider) { + c.embeddingProvider = embeddingProvider +} + +// SetTracingHandler sets the tracing handler for observability +func (c *Client) SetTracingHandler(tracingHandler interface{}) { + c.tracingHandler = tracingHandler +} + // CallTool implements the MCP tool interface for RAG operations func (c *Client) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (string, error) { if args == nil { @@ -71,7 +92,7 @@ func (c *Client) CallTool(ctx context.Context, toolName string, args map[string] } } -// handleRAGSearch processes search requests +// handleRAGSearch processes search requests with enhanced pipeline func (c *Client) handleRAGSearch(ctx context.Context, args map[string]interface{}) (string, error) { // Extract and validate query parameter query, err := c.extractStringParam(args, "query", true) @@ -79,10 +100,130 @@ func (c *Client) handleRAGSearch(ctx context.Context, args map[string]interface{ return "", err } - // Perform search using the provider - results, err := c.provider.Search(ctx, query, SearchOptions{}) - if err != nil { - return "", fmt.Errorf("search failed: %w", err) + // Build search options + // Extract max_results from config, default to 20 + maxResults := 20 + if c.config != nil { + if maxResultsFloat, ok := c.config["max_results"].(float64); ok { + maxResults = int(maxResultsFloat) + } else if maxResultsInt, ok := c.config["max_results"].(int); ok { + maxResults = maxResultsInt + } + } + + searchOpts := SearchOptions{ + Limit: maxResults, + Metadata: make(map[string]string), + } + + // 2. 
Date filter logging + if queryMetadataRaw, ok := args["query_metadata"]; ok { + if metadata, ok := queryMetadataRaw.(*MetadataFilters); ok && metadata != nil { + // Extract date filter if present (LLM provides the exact list of dates) + if len(metadata.Dates) > 0 { + fmt.Printf("[RAG Date Filter] Detected temporal query with %d dates from LLM\n", len(metadata.Dates)) + fmt.Printf("[RAG Date Filter] Dates: %v\n", metadata.Dates) + + // Use the dates directly from LLM - no expansion needed + searchOpts.DateFilter = metadata.Dates + } else { + fmt.Printf("[RAG Date Filter] No date filter - non-temporal query\n") + } + } + } else { + fmt.Printf("[RAG Date Filter] No query metadata provided\n") + } + + if c.embeddingProvider != nil { + // Create embedding span if tracing is enabled + var embResult *EmbeddingResult + if tracer, ok := c.tracingHandler.(observability.TracingHandler); ok && tracer != nil { + embCtx, embSpan := tracer.StartSpan(ctx, "query-embedding-creation", "embedding", query, map[string]string{ + "provider": "voyage", + }) + + startTime := time.Now() + result, err := c.embeddingProvider.EmbedQuery(embCtx, query) + duration := time.Since(startTime) + + tracer.SetDuration(embSpan, duration) + + if err != nil { + tracer.RecordError(embSpan, err, "ERROR") + embSpan.End() + return "", fmt.Errorf("failed to embed query: %w", err) + } + + embResult = result + + // Set usage and cost details + + // Set token usage (input tokens only for embeddings) + tracer.SetTokenUsage(embSpan, result.TokensUsed, 0, 0, result.TokensUsed) + + // Set output: embedding dimensions + tracer.SetOutput(embSpan, fmt.Sprintf("Generated %d-dimensional embedding (%d tokens)", + len(result.Embedding), result.TokensUsed)) + + tracer.RecordSuccess(embSpan, fmt.Sprintf("Embedding generated: model=%s", result.Model)) + embSpan.End() + } else { + // No tracing, just call embedding + result, err := c.embeddingProvider.EmbedQuery(ctx, query) + if err != nil { + return "", fmt.Errorf("failed to embed query: %w", err) + } + embResult = result + } + + searchOpts.QueryVector = embResult.Embedding + } + + // 3. S3 search parameters logging + fmt.Printf("[RAG Search] Query: '%s'\n", query) + fmt.Printf("[RAG Search] Max results: %d\n", searchOpts.Limit) + fmt.Printf("[RAG Search] Has embedding vector: %v (dimensions: %d)\n", + len(searchOpts.QueryVector) > 0, len(searchOpts.QueryVector)) + fmt.Printf("[RAG Search] Date filter count: %d dates\n", len(searchOpts.DateFilter)) + + // 4. 
Vector search/retrieval with tracing + var results []SearchResult + if tracer, ok := c.tracingHandler.(observability.TracingHandler); ok && tracer != nil { + // Create retriever span for vector store query + retrieverCtx, retrieverSpan := tracer.StartSpan(ctx, "vector-search", "retriever", query, map[string]string{ + "provider": fmt.Sprintf("%T", c.provider), + "max_results": fmt.Sprintf("%d", searchOpts.Limit), + "has_embedding_vector": fmt.Sprintf("%t", len(searchOpts.QueryVector) > 0), + "embedding_dimensions": fmt.Sprintf("%d", len(searchOpts.QueryVector)), + "date_filter_count": fmt.Sprintf("%d", len(searchOpts.DateFilter)), + }) + + startTime := time.Now() + searchResults, err := c.provider.Search(retrieverCtx, query, searchOpts) + duration := time.Since(startTime) + + tracer.SetDuration(retrieverSpan, duration) + + if err != nil { + tracer.RecordError(retrieverSpan, err, "ERROR") + retrieverSpan.End() + return "", fmt.Errorf("search failed: %w", err) + } + + results = searchResults + + // Set output with result summary + tracer.SetOutput(retrieverSpan, fmt.Sprintf("Retrieved %d documents from vector store (duration: %v)", + len(results), duration)) + tracer.RecordSuccess(retrieverSpan, fmt.Sprintf("Vector search completed: %d results", len(results))) + retrieverSpan.End() + } else { + // No tracing, just call search directly + searchResults, err := c.provider.Search(ctx, query, searchOpts) + if err != nil { + return "", fmt.Errorf("search failed: %w", err) + } + results = searchResults } // Format results for display @@ -90,6 +231,17 @@ func (c *Client) handleRAGSearch(ctx context.Context, args map[string]interface{ return "No relevant context found for query: '" + query + "'", nil } + // Get date filter field for sorting and display (if configured) + dateFilterField := "" + if c.config != nil { + if field, ok := c.config["date_filter_field"].(string); ok && field != "" { + dateFilterField = field + } + } + + // TODO: Add reranking step here in the future + sortResultsByDate(results, dateFilterField) + // Build response string var response strings.Builder response.WriteString(fmt.Sprintf("Found %d relevant context(s) for '%s':\n", len(results), query)) @@ -106,6 +258,13 @@ func (c *Client) handleRAGSearch(ctx context.Context, args map[string]interface{ response.WriteString("\n") } + // Add metadata if available (use configured date field) + if dateFilterField != "" { + if date, exists := result.Metadata[dateFilterField]; exists { + response.WriteString(fmt.Sprintf("Date: %s\n", date)) + } + } + // Add content response.WriteString(fmt.Sprintf("Content: %s\n", result.Content)) @@ -118,6 +277,37 @@ func (c *Client) handleRAGSearch(ctx context.Context, args map[string]interface{ return response.String(), nil } +// sortResultsByDate sorts results by the configured date field in descending order (newest first) +// If dateField is empty, no sorting is performed +// Uses sort.Slice for O(n log n) performance +func sortResultsByDate(results []SearchResult, dateField string) { + if dateField == "" { + fmt.Printf("[Sort] Skipping sort - dateField is empty\n") + return // Skip sorting if no date field configured + } + + fmt.Printf("[Sort] Sorting %d results by field '%s'\n", len(results), dateField) + + // Log first 3 dates before sorting + for i := 0; i < len(results) && i < 3; i++ { + date := results[i].Metadata[dateField] + fmt.Printf("[Sort] Before[%d]: %s (source: %s)\n", i, date, results[i].FileName) + } + + sort.Slice(results, func(i, j int) bool { + dateI := 
results[i].Metadata[dateField] + dateJ := results[j].Metadata[dateField] + return dateI > dateJ // Descending order (newest first) + }) + + // Log first 3 dates after sorting + fmt.Printf("[Sort] After sorting:\n") + for i := 0; i < len(results) && i < 3; i++ { + date := results[i].Metadata[dateField] + fmt.Printf("[Sort] After[%d]: %s (source: %s)\n", i, date, results[i].FileName) + } +} + // handleRAGIngest processes document ingestion requests func (c *Client) handleRAGIngest(ctx context.Context, args map[string]interface{}) (string, error) { // Extract file path parameter diff --git a/internal/rag/date_utils.go b/internal/rag/date_utils.go new file mode 100644 index 0000000..648bd70 --- /dev/null +++ b/internal/rag/date_utils.go @@ -0,0 +1,34 @@ +package rag + +import ( + "time" +) + +// ExpandDateRange takes a date string in YYYY-MM-DD format and returns an array +// of date strings spanning backwards from the given date for the specified number of days. +// Example: ExpandDateRange("2025-10-14", 7) returns ["2025-10-14", "2025-10-13", ..., "2025-10-08"] +func ExpandDateRange(dateStr string, days int) ([]string, error) { + if dateStr == "" { + return nil, nil + } + + // Parse the date string + targetDate, err := time.Parse("2006-01-02", dateStr) + if err != nil { + return nil, err + } + + // Generate array from target date backwards + dateArray := make([]string, 0, days) + for i := 0; i < days; i++ { + date := targetDate.AddDate(0, 0, -i) + dateArray = append(dateArray, date.Format("2006-01-02")) + } + + return dateArray, nil +} + +// GetTodayDate returns today's date in YYYY-MM-DD format +func GetTodayDate() string { + return time.Now().Format("2006-01-02") +} diff --git a/internal/rag/date_utils_test.go b/internal/rag/date_utils_test.go new file mode 100644 index 0000000..9519bf0 --- /dev/null +++ b/internal/rag/date_utils_test.go @@ -0,0 +1,181 @@ +package rag + +import ( + "testing" + "time" +) + +// TestExpandDateRange tests the date range expansion function +func TestExpandDateRange(t *testing.T) { + tests := []struct { + name string + date string + days int + wantLen int + wantErr bool + wantLast string // Expected last date in range + }{ + { + name: "7 days from 2025-10-31", + date: "2025-10-31", + days: 7, + wantLen: 7, + wantErr: false, + wantLast: "2025-10-25", + }, + { + name: "3 days from 2025-10-15", + date: "2025-10-15", + days: 3, + wantLen: 3, + wantErr: false, + wantLast: "2025-10-13", + }, + { + name: "1 day (same date)", + date: "2025-10-31", + days: 1, + wantLen: 1, + wantErr: false, + wantLast: "2025-10-31", + }, + { + name: "invalid date format", + date: "2025/10/31", + days: 7, + wantErr: true, + }, + { + name: "empty date", + date: "", + days: 7, + wantLen: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ExpandDateRange(tt.date, tt.days) + + if tt.wantErr { + if err == nil { + t.Errorf("ExpandDateRange() expected error, got none") + } + return + } + + if err != nil { + t.Errorf("ExpandDateRange() unexpected error = %v", err) + return + } + + if len(got) != tt.wantLen { + t.Errorf("ExpandDateRange() length = %d, want %d", len(got), tt.wantLen) + } + + if tt.wantLen > 0 { + // First date should be the input date + if got[0] != tt.date { + t.Errorf("ExpandDateRange() first date = %s, want %s", got[0], tt.date) + } + + // Last date should match expected + if tt.wantLast != "" && got[len(got)-1] != tt.wantLast { + t.Errorf("ExpandDateRange() last date = %s, want %s", got[len(got)-1], tt.wantLast) + } 
+ + // Dates should be in descending order (newest first) + for i := 1; i < len(got); i++ { + if got[i] >= got[i-1] { + t.Errorf("ExpandDateRange() dates not in descending order: %v", got) + break + } + } + + t.Logf("Date range: %v", got) + } + }) + } +} + +// TestGetTodayDate tests the today's date function +func TestGetTodayDate(t *testing.T) { + got := GetTodayDate() + + // Verify format YYYY-MM-DD + _, err := time.Parse("2006-01-02", got) + if err != nil { + t.Errorf("GetTodayDate() returned invalid format: %s, error: %v", got, err) + } + + // Verify it's today's date + expected := time.Now().Format("2006-01-02") + if got != expected { + t.Errorf("GetTodayDate() = %s, want %s", got, expected) + } + + t.Logf("Today's date: %s", got) +} + +// TestExpandDateRange_EdgeCases tests edge cases +func TestExpandDateRange_EdgeCases(t *testing.T) { + t.Run("month boundary", func(t *testing.T) { + // Starting from Oct 3, going back 7 days should cross into September + got, err := ExpandDateRange("2025-10-03", 7) + if err != nil { + t.Fatalf("ExpandDateRange() error = %v", err) + } + + if len(got) != 7 { + t.Errorf("Expected 7 dates, got %d", len(got)) + } + + // Should include dates from September + lastDate := got[len(got)-1] + if lastDate != "2025-09-27" { + t.Errorf("Expected last date to be 2025-09-27, got %s", lastDate) + } + + t.Logf("Month boundary range: %v", got) + }) + + t.Run("year boundary", func(t *testing.T) { + // Starting from Jan 3, going back 7 days should cross into previous year + got, err := ExpandDateRange("2025-01-03", 7) + if err != nil { + t.Fatalf("ExpandDateRange() error = %v", err) + } + + // Should include dates from December 2024 + lastDate := got[len(got)-1] + if lastDate != "2024-12-28" { + t.Errorf("Expected last date to be 2024-12-28, got %s", lastDate) + } + + t.Logf("Year boundary range: %v", got) + }) + + t.Run("leap year", func(t *testing.T) { + // Testing around Feb 29 in a leap year + got, err := ExpandDateRange("2024-03-01", 5) + if err != nil { + t.Fatalf("ExpandDateRange() error = %v", err) + } + + // Should include Feb 29, 2024 (leap day) + found := false + for _, date := range got { + if date == "2024-02-29" { + found = true + break + } + } + + if !found { + t.Errorf("Expected to find 2024-02-29 (leap day) in range: %v", got) + } + + t.Logf("Leap year range: %v", got) + }) +} diff --git a/internal/rag/embedding_factory.go b/internal/rag/embedding_factory.go new file mode 100644 index 0000000..37beb43 --- /dev/null +++ b/internal/rag/embedding_factory.go @@ -0,0 +1,25 @@ +package rag + +import ( + "fmt" +) + +// EmbeddingProviderConfig contains embedding provider-specific settings +type EmbeddingProviderConfig struct { + APIKey string `json:"apiKey,omitempty"` // API key for the embedding provider +} + +// CreateEmbeddingProvider creates an embedding provider based on the provider name and config +// Supports: voyage, openai (future), cohere (future), etc. 
+func CreateEmbeddingProvider(providerName string, config EmbeddingProviderConfig) (EmbeddingProvider, error) { + switch providerName { + case "voyage": + if config.APIKey == "" { + return nil, fmt.Errorf("API key is required for Voyage embedding provider") + } + return NewVoyageClient(config.APIKey), nil + + default: + return nil, fmt.Errorf("unsupported embedding provider: %s (supported: voyage)", providerName) + } +} diff --git a/internal/rag/provider_interface.go b/internal/rag/provider_interface.go index b95834a..d3f8c47 100644 --- a/internal/rag/provider_interface.go +++ b/internal/rag/provider_interface.go @@ -28,6 +28,14 @@ type VectorProvider interface { GetStats(ctx context.Context) (*VectorStoreStats, error) } +// EmbeddingProvider defines the interface for embedding model providers +// This abstraction allows switching between different embedding providers (Voyage, OpenAI, Cohere, etc.) +type EmbeddingProvider interface { + // EmbedQuery generates embeddings for a query string + // Returns embedding result with vector, token usage, and model info + EmbedQuery(ctx context.Context, query string) (*EmbeddingResult, error) +} + // FileInfo represents information about a file in the vector store type FileInfo struct { ID string @@ -40,9 +48,11 @@ type FileInfo struct { // SearchOptions configures search parameters type SearchOptions struct { - Limit int // Maximum number of results - MinScore float32 // Minimum relevance score - Metadata map[string]string // Filter by metadata + Limit int // Maximum number of results + MinScore float32 // Minimum relevance score + Metadata map[string]string // Filter by metadata + DateFilter []int // Date range filter (YYYYMMDD integer format) + QueryVector []float32 // Pre-computed query embedding vector } // SearchResult represents a search result from the vector store diff --git a/internal/rag/query_enhancer.go b/internal/rag/query_enhancer.go new file mode 100644 index 0000000..5bfc8d6 --- /dev/null +++ b/internal/rag/query_enhancer.go @@ -0,0 +1,109 @@ +package rag + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/tuannvm/slack-mcp-client/internal/llm" +) + +// MetadataFilters represents the metadata filters extracted from a query +type MetadataFilters struct { + BusinessUnits []string `json:"business_units,omitempty"` + Regions []string `json:"regions,omitempty"` + Dates []int `json:"dates,omitempty"` // List of dates for temporal queries (YYYYMMDD integer format) + Labels []string `json:"labels,omitempty"` +} + +// EnhancedQuery represents the result of query enhancement +type EnhancedQuery struct { + EnhancedQuery string `json:"enhanced_query"` + MetadataFilters MetadataFilters `json:"metadata_filters"` + OriginalQuery string `json:"-"` // Not from LLM response +} + +// QueryEnhancer enhances queries using LLM +type QueryEnhancer struct { + llmRegistry *llm.ProviderRegistry +} + +// NewQueryEnhancer creates a new query enhancer +func NewQueryEnhancer(llmRegistry *llm.ProviderRegistry) *QueryEnhancer { + return &QueryEnhancer{ + llmRegistry: llmRegistry, + } +} + +// EnhanceQuery enhances a query by extracting metadata filters and improving the query text +func (qe *QueryEnhancer) EnhanceQuery(ctx context.Context, query string, today string, promptTemplate string) (*EnhancedQuery, error) { + // Build the prompt by replacing placeholders + prompt := strings.ReplaceAll(promptTemplate, "{today}", today) + prompt = strings.ReplaceAll(prompt, "{query}", query) + + // Get the primary LLM provider from registry + 
provider, err := qe.llmRegistry.GetPrimaryProvider() + if err != nil { + return nil, fmt.Errorf("failed to get LLM provider: %w", err) + } + + // Prepare message + messages := []llm.RequestMessage{ + { + Role: "user", + Content: prompt, + }, + } + + // Call LLM with the prompt + response, err := provider.GenerateChatCompletion(ctx, messages, llm.ProviderOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to call LLM: %w", err) + } + + responseText := response.Content + + // Parse the JSON response + var result EnhancedQuery + if err := json.Unmarshal([]byte(responseText), &result); err != nil { + // Try to extract JSON from code blocks if direct parsing fails + responseText = extractJSONFromCodeBlock(responseText) + if err := json.Unmarshal([]byte(responseText), &result); err != nil { + return nil, fmt.Errorf("failed to parse LLM response as JSON: %w, response: %s", err, responseText) + } + } + + // Set the original query + result.OriginalQuery = query + + return &result, nil +} + +// extractJSONFromCodeBlock extracts JSON from markdown code blocks +func extractJSONFromCodeBlock(text string) string { + // Try to find JSON in ```json or ``` code blocks + text = strings.TrimSpace(text) + + // Look for opening fence (```json or ```) + startIdx := strings.Index(text, "```json") + fenceType := "```json" + if startIdx == -1 { + startIdx = strings.Index(text, "```") + fenceType = "```" + } + + // If we found an opening fence, extract content between fences + if startIdx >= 0 { + // Move past the opening fence + text = text[startIdx+len(fenceType):] + text = strings.TrimSpace(text) + + // Find the closing ``` (even if there's text after it) + if endIdx := strings.Index(text, "```"); endIdx >= 0 { + text = text[:endIdx] + } + } + + return strings.TrimSpace(text) +} diff --git a/internal/rag/query_enhancer_test.go b/internal/rag/query_enhancer_test.go new file mode 100644 index 0000000..7fdaf0b --- /dev/null +++ b/internal/rag/query_enhancer_test.go @@ -0,0 +1,179 @@ +package rag + +import ( + "context" + "os" + "testing" + + "github.com/tuannvm/slack-mcp-client/internal/common/logging" + "github.com/tuannvm/slack-mcp-client/internal/config" + "github.com/tuannvm/slack-mcp-client/internal/llm" +) + +// createTestLLMRegistry creates a real LLM registry for testing +func createTestLLMRegistry(t *testing.T) *llm.ProviderRegistry { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY not set, skipping integration test") + } + + // Create a minimal config for Anthropic + cfg := &config.Config{ + LLM: config.LLMConfig{ + Provider: "anthropic", + Providers: map[string]config.LLMProviderConfig{ + "anthropic": { + Model: "claude-sonnet-4-5-20250929", + APIKey: apiKey, + }, + }, + }, + } + + logger := logging.New("test", logging.LevelInfo) + registry, err := llm.NewProviderRegistry(cfg, logger) + if err != nil { + t.Fatalf("Failed to create LLM registry: %v", err) + } + + return registry +} + +// TestQueryEnhancer_EnhanceQuery_Temporal tests temporal query enhancement with real Claude +// Requires ANTHROPIC_API_KEY environment variable +func TestQueryEnhancer_EnhanceQuery_Temporal(t *testing.T) { + t.Skip("Test requires prompt template file - skipping for now") + + registry := createTestLLMRegistry(t) + enhancer := NewQueryEnhancer(registry) + + ctx := context.Background() + // Note: In real usage, prompt template would be loaded from file + promptTemplate := "Test prompt with {today} and {query} placeholders" + result, err := enhancer.EnhanceQuery(ctx, "What 
were Q3 2025 revenues for VX in APAC?", "2025-10-31", promptTemplate) + + if err != nil { + t.Fatalf("EnhanceQuery() error = %v", err) + } + + // Verify enhanced query + if result.EnhancedQuery == "" { + t.Errorf("EnhanceQuery() returned empty enhanced query") + } + + t.Logf("Original query: %s", result.OriginalQuery) + t.Logf("Enhanced query: %s", result.EnhancedQuery) + + // Verify metadata filters for temporal query + if len(result.MetadataFilters.Dates) == 0 { + t.Logf("WARNING: No dates returned (expected for temporal query)") + } else { + t.Logf("Dates: %v", result.MetadataFilters.Dates) + } + + t.Logf("Business units: %v", result.MetadataFilters.BusinessUnits) + t.Logf("Regions: %v", result.MetadataFilters.Regions) + t.Logf("Labels: %v", result.MetadataFilters.Labels) +} + +// TestQueryEnhancer_EnhanceQuery_NonTemporal tests non-temporal query enhancement with real Claude +// Requires ANTHROPIC_API_KEY environment variable +func TestQueryEnhancer_EnhanceQuery_NonTemporal(t *testing.T) { + t.Skip("Test requires prompt template file - skipping for now") + + registry := createTestLLMRegistry(t) + enhancer := NewQueryEnhancer(registry) + + ctx := context.Background() + promptTemplate := "Test prompt with {today} and {query} placeholders" + result, err := enhancer.EnhanceQuery(ctx, "What is ROAS?", "2025-10-31", promptTemplate) + + if err != nil { + t.Fatalf("EnhanceQuery() error = %v", err) + } + + t.Logf("Original query: %s", result.OriginalQuery) + t.Logf("Enhanced query: %s", result.EnhancedQuery) + + // Verify no date for non-temporal (knowledge) query + if len(result.MetadataFilters.Dates) > 0 { + t.Logf("WARNING: Dates returned for non-temporal query: %v", result.MetadataFilters.Dates) + } else { + t.Logf("Correctly returned empty dates for non-temporal query") + } + + t.Logf("Labels: %v", result.MetadataFilters.Labels) +} + +// TestQueryEnhancer_EnhanceQuery_RecentQuery tests "recent" keyword handling +// Requires ANTHROPIC_API_KEY environment variable +func TestQueryEnhancer_EnhanceQuery_RecentQuery(t *testing.T) { + t.Skip("Test requires prompt template file - skipping for now") + + registry := createTestLLMRegistry(t) + enhancer := NewQueryEnhancer(registry) + + ctx := context.Background() + promptTemplate := "Test prompt with {today} and {query} placeholders" + result, err := enhancer.EnhanceQuery(ctx, "recent sales performance", "2025-10-31", promptTemplate) + + if err != nil { + t.Fatalf("EnhanceQuery() error = %v", err) + } + + t.Logf("Original query: %s", result.OriginalQuery) + t.Logf("Enhanced query: %s", result.EnhancedQuery) + + // "recent" should trigger temporal behavior + if len(result.MetadataFilters.Dates) == 0 { + t.Logf("WARNING: 'recent' query should return dates") + } else { + t.Logf("Dates for 'recent' query: %v", result.MetadataFilters.Dates) + } + + t.Logf("Labels: %v", result.MetadataFilters.Labels) +} + +// TestExtractJSONFromCodeBlock tests the JSON extraction utility +func TestExtractJSONFromCodeBlock(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "json code block", + input: "```json\n{\"key\": \"value\"}\n```", + want: "{\"key\": \"value\"}", + }, + { + name: "plain code block", + input: "```\n{\"key\": \"value\"}\n```", + want: "{\"key\": \"value\"}", + }, + { + name: "no code block", + input: "{\"key\": \"value\"}", + want: "{\"key\": \"value\"}", + }, + { + name: "with explanation after code block", + input: "```json\n{\"key\": \"value\"}\n```\n\n**Reasoning:**\nSome explanation", + want: 
"{\"key\": \"value\"}", + }, + { + name: "with text before and after", + input: "Here's the result:\n```json\n{\"key\": \"value\"}\n```\nDone!", + want: "{\"key\": \"value\"}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractJSONFromCodeBlock(tt.input) + if got != tt.want { + t.Errorf("extractJSONFromCodeBlock() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/rag/s3_provider.go b/internal/rag/s3_provider.go new file mode 100644 index 0000000..06d7e44 --- /dev/null +++ b/internal/rag/s3_provider.go @@ -0,0 +1,240 @@ +package rag + +import ( + "context" + "fmt" + "sync" + + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3vectors" + "github.com/aws/aws-sdk-go-v2/service/s3vectors/document" + "github.com/aws/aws-sdk-go-v2/service/s3vectors/types" + "github.com/tuannvm/slack-mcp-client/internal/common/logging" +) + +// S3Provider implements VectorProvider using AWS S3 as storage backend +type S3Provider struct { + bucketName string + indexName string + region string + config map[string]interface{} + s3vectorsClient *s3vectors.Client + logger *logging.Logger + initOnce sync.Once // Ensures initialization happens exactly once + initErr error // Stores initialization error +} + +// NewS3Provider creates a new S3-based vector provider +func NewS3Provider(config map[string]interface{}) (VectorProvider, error) { + bucketName, ok := config["bucket_name"].(string) + if !ok || bucketName == "" { + return nil, fmt.Errorf("bucket_name is required in S3 provider config") + } + + indexName, ok := config["index_name"].(string) + if !ok || indexName == "" { + indexName = "default" // default index name + } + + region, ok := config["region"].(string) + if !ok || region == "" { + region = "us-east-1" // default region + } + + // Create logger for S3 provider + logger := logging.New("s3-provider", logging.LevelInfo) + + return &S3Provider{ + bucketName: bucketName, + indexName: indexName, + region: region, + config: config, + logger: logger, + }, nil +} + +// Initialize sets up the S3 vector provider +// Thread-safe: Uses sync.Once to ensure initialization happens exactly once +func (s *S3Provider) Initialize(ctx context.Context) error { + s.initOnce.Do(func() { + cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(s.region)) + if err != nil { + s.initErr = fmt.Errorf("failed to load AWS config: %w", err) + return + } + s.s3vectorsClient = s3vectors.NewFromConfig(cfg) + }) + return s.initErr +} + +// IngestFile ingests a single file into the vector store +func (s *S3Provider) IngestFile(ctx context.Context, filePath string, metadata map[string]string) (string, error) { + // TODO: Upload file to S3, process and vectorize content + return "", fmt.Errorf("not implemented") +} + +// IngestFiles ingests multiple files into the vector store +func (s *S3Provider) IngestFiles(ctx context.Context, filePaths []string, metadata map[string]string) ([]string, error) { + // TODO: Batch upload files to S3, process and vectorize content + return nil, fmt.Errorf("not implemented") +} + +// DeleteFile removes a file from the vector store +func (s *S3Provider) DeleteFile(ctx context.Context, fileID string) error { + // TODO: Delete file from S3 and remove vectors + return fmt.Errorf("not implemented") +} + +// ListFiles lists files in the vector store +func (s *S3Provider) ListFiles(ctx context.Context, limit int) ([]FileInfo, error) { + // TODO: List files from S3 bucket + return nil, fmt.Errorf("not 
implemented") +} + +// Search performs a vector similarity search +func (s *S3Provider) Search(ctx context.Context, query string, options SearchOptions) ([]SearchResult, error) { + if s.s3vectorsClient == nil { + return nil, fmt.Errorf("s3 vectors client not initialized") + } + + // S3 provider requires pre-computed query vector + if len(options.QueryVector) == 0 { + return nil, fmt.Errorf("query vector is required in SearchOptions for S3 provider") + } + + // Set default limit if not specified + limit := int32(options.Limit) + if limit <= 0 { + limit = 7 + } + + // Build the query input + input := &s3vectors.QueryVectorsInput{ + VectorBucketName: &s.bucketName, + IndexName: &s.indexName, + QueryVector: &types.VectorDataMemberFloat32{Value: options.QueryVector}, + TopK: &limit, + ReturnDistance: true, + ReturnMetadata: true, + } + + // Build filter from options (generic - caller provides business logic) + if len(options.DateFilter) > 0 || len(options.Metadata) > 0 { + filter := make(map[string]interface{}) + + // Add date filter ONLY if date_filter_field is configured + if len(options.DateFilter) > 0 && s.config != nil { + if dateFilterField, ok := s.config["date_filter_field"].(string); ok && dateFilterField != "" { + filter[dateFilterField] = map[string]interface{}{ + "$in": options.DateFilter, + } + s.logger.DebugKV("Applying date filter", "field", dateFilterField, "dates", options.DateFilter) + } else { + // If date_filter_field is not configured, skip date filtering entirely + s.logger.InfoKV("Date filter not applied: date_filter_field not configured", "provided_dates", options.DateFilter) + } + } + + // Add other metadata filters + for key, value := range options.Metadata { + filter[key] = value + } + + // Wrap in document.NewLazyDocument for AWS SDK + input.Filter = document.NewLazyDocument(filter) + } + + // Execute the vector query + output, err := s.s3vectorsClient.QueryVectors(ctx, input) + if err != nil { + return nil, fmt.Errorf("failed to query vectors: %w", err) + } + + // Convert results to SearchResult format + results := make([]SearchResult, 0, len(output.Vectors)) + for _, vector := range output.Vectors { + // Calculate score from distance (assuming lower distance = higher score) + score := float32(1.0) + if vector.Distance != nil { + // Convert distance to similarity score (inverse relationship) + // You may want to adjust this formula based on your distance metric + score = 1.0 / (1.0 + *vector.Distance) + } + + searchResult := SearchResult{ + Score: score, + FileID: *vector.Key, + FileName: *vector.Key, // Using Key as filename for now + Metadata: make(map[string]string), + } + + // Extract content and metadata from S3 response + if vector.Metadata != nil { + // Use Smithy document Unmarshaler to convert to map + var metadataMap map[string]interface{} + err := vector.Metadata.UnmarshalSmithyDocument(&metadataMap) + if err == nil { + // Extract source_text as content + if sourceText, exists := metadataMap["source_text"]; exists { + if text, ok := sourceText.(string); ok { + searchResult.Content = text + } + } + + // Convert metadata to string map + for key, value := range metadataMap { + if key == "source_text" { + continue // Skip source_text as it's already in Content + } + // Convert value to string + searchResult.Metadata[key] = fmt.Sprintf("%v", value) + } + + // Log doc_id and date field (if configured) + vectorKey := *vector.Key + if dateField, ok := s.config["date_filter_field"].(string); ok && dateField != "" { + if reportDate, exists := 
searchResult.Metadata[dateField]; exists {
+						s.logger.DebugKV("S3 vector result", "vector_key", vectorKey, "date_field", dateField, "date_value", reportDate, "score", score)
+					} else {
+						s.logger.DebugKV("S3 vector result", "vector_key", vectorKey, "score", score)
+					}
+				} else {
+					s.logger.DebugKV("S3 vector result", "vector_key", vectorKey, "score", score)
+				}
+			} else {
+				// Failed to unmarshal metadata - skip this result
+				s.logger.ErrorKV("Failed to unmarshal vector metadata, skipping result",
+					"vector_key", *vector.Key,
+					"score", score,
+					"error", err)
+				continue
+			}
+		}
+
+		// Apply minimum score filter
+		if options.MinScore > 0 && searchResult.Score < options.MinScore {
+			continue
+		}
+
+		results = append(results, searchResult)
+	}
+
+	return results, nil
+}
+
+// Close cleans up resources
+func (s *S3Provider) Close() error {
+	// TODO: Clean up S3 client and connections
+	return nil
+}
+
+// GetStats returns statistics about the vector store
+func (s *S3Provider) GetStats(ctx context.Context) (*VectorStoreStats, error) {
+	// TODO: Gather stats from S3 bucket
+	return &VectorStoreStats{}, fmt.Errorf("not implemented")
+}
+
+func init() {
+	// Register the S3 provider factory
+	RegisterVectorProvider("s3", NewS3Provider)
+}
diff --git a/internal/rag/s3_provider_test.go b/internal/rag/s3_provider_test.go
new file mode 100644
index 0000000..1d686b9
--- /dev/null
+++ b/internal/rag/s3_provider_test.go
@@ -0,0 +1,250 @@
+package rag
+
+import (
+	"context"
+	"os"
+	"testing"
+)
+
+// TestS3Provider_Search tests the S3 vector search functionality
+// This is an integration test that requires:
+// - AWS_REGION, S3_VECTOR_BUCKET, S3_VECTOR_INDEX environment variables
+// - AWS credentials configured (via AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, or IAM role)
+// - An existing S3 vector store with indexed data
+func TestS3Provider_Search(t *testing.T) {
+	bucketName := os.Getenv("S3_VECTOR_BUCKET")
+	indexName := os.Getenv("S3_VECTOR_INDEX")
+	region := os.Getenv("AWS_REGION")
+
+	if bucketName == "" || indexName == "" {
+		t.Skip("S3_VECTOR_BUCKET or S3_VECTOR_INDEX not set, skipping integration test")
+	}
+
+	if region == "" {
+		region = "us-east-1"
+	}
+
+	// Create S3 provider
+	config := map[string]interface{}{
+		"bucket_name": bucketName,
+		"index_name":  indexName,
+		"region":      region,
+	}
+
+	provider, err := NewS3Provider(config)
+	if err != nil {
+		t.Fatalf("NewS3Provider() error = %v", err)
+	}
+
+	// Initialize the provider
+	ctx := context.Background()
+	err = provider.Initialize(ctx)
+	if err != nil {
+		t.Fatalf("Initialize() error = %v", err)
+	}
+
+	// Test vector search
+	// Note: This requires a pre-computed query vector
+	// In real usage, this would come from Voyage embeddings
+	tests := []struct {
+		name        string
+		query       string
+		queryVector []float32
+		options     SearchOptions
+		wantErr     bool
+	}{
+		{
+			name:  "search with no vector should fail",
+			query: "test query",
+			options: SearchOptions{
+				Limit: 5,
+			},
+			wantErr: true,
+		},
+		// TODO: Add test with actual query vector once we have sample data
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tt.options.QueryVector = tt.queryVector
+
+			results, err := provider.Search(ctx, tt.query, tt.options)
+
+			if tt.wantErr {
+				if err == nil {
+					t.Errorf("Search() expected error, got none")
+				}
+				return
+			}
+
+			if err != nil {
+				t.Errorf("Search() error = %v", err)
+				return
+			}
+
+			t.Logf("Search returned %d results", len(results))
+
+			// Verify results structure
+			for i, result := range results {
+				if result.FileID == "" {
+					t.Errorf("Result %d has empty FileID", i)
+				}
+				if result.Score < 0 {
+					t.Errorf("Result %d has negative score: %f", i, result.Score)
+				}
+				t.Logf("Result %d: FileID=%s, Score=%.4f, Content length=%d",
+					i, result.FileID, result.Score, len(result.Content))
+			}
+		})
+	}
+}
+
+// TestS3Provider_Search_WithFilters tests metadata filtering
+func TestS3Provider_Search_WithFilters(t *testing.T) {
+	bucketName := os.Getenv("S3_VECTOR_BUCKET")
+	indexName := os.Getenv("S3_VECTOR_INDEX")
+	voyageAPIKey := os.Getenv("VOYAGE_API_KEY")
+
+	if bucketName == "" || indexName == "" {
+		t.Skip("S3_VECTOR_BUCKET or S3_VECTOR_INDEX not set, skipping integration test")
+	}
+
+	if voyageAPIKey == "" {
+		t.Skip("VOYAGE_API_KEY not set, skipping integration test")
+	}
+
+	// Create Voyage client to generate real embeddings
+	voyageClient := NewVoyageClient(voyageAPIKey)
+	ctx := context.Background()
+
+	// Generate query embedding using Voyage
+	embeddingResult, err := voyageClient.EmbedQuery(ctx, "revenue performance metrics")
+	if err != nil {
+		t.Fatalf("Failed to generate query embedding: %v", err)
+	}
+	t.Logf("Generated query embedding with %d dimensions", len(embeddingResult.Embedding))
+
+	config := map[string]interface{}{
+		"bucket_name": bucketName,
+		"index_name":  indexName,
+		"region":      "us-east-1",
+	}
+
+	provider, err := NewS3Provider(config)
+	if err != nil {
+		t.Fatalf("NewS3Provider() error = %v", err)
+	}
+
+	err = provider.Initialize(ctx)
+	if err != nil {
+		t.Fatalf("Initialize() error = %v", err)
+	}
+
+	// Test with date filter
+	t.Run("search with date filter", func(t *testing.T) {
+		dateFilter := []int{20251031, 20251030, 20251029}
+
+		results, err := provider.Search(ctx, "revenue performance metrics", SearchOptions{
+			QueryVector: embeddingResult.Embedding,
+			DateFilter:  dateFilter,
+			Limit:       5,
+		})
+
+		if err != nil {
+			t.Logf("Search with date filter: %v (may fail if no matching data)", err)
+		} else {
+			t.Logf("Found %d results with date filter", len(results))
+			for _, result := range results {
+				if date, exists := result.Metadata["report_generated_date"]; exists {
+					t.Logf("  - Date: %s", date)
+				}
+			}
+		}
+	})
+
+	// Test with metadata filter
+	t.Run("search with metadata filter", func(t *testing.T) {
+		results, err := provider.Search(ctx, "revenue performance metrics", SearchOptions{
+			QueryVector: embeddingResult.Embedding,
+			Metadata: map[string]string{
+				"business_units": "VX",
+			},
+			Limit: 5,
+		})

+		if err != nil {
+			t.Logf("Search with metadata filter: %v (may fail if no matching data)", err)
+		} else {
+			t.Logf("Found %d results with business unit filter", len(results))
+		}
+	})
+
+	// Test with no filter
+	t.Run("search with no filter", func(t *testing.T) {
+		results, err := provider.Search(ctx, "revenue performance metrics", SearchOptions{
+			QueryVector: embeddingResult.Embedding,
+			Limit:       5,
+		})
+
+		if err != nil {
+			t.Logf("Search with no metadata filter: %v (may fail if no matching data)", err)
+		} else {
+			t.Logf("Found %d results with no filter", len(results))
+		}
+	})
+}
+
+// TestS3Provider_Config tests configuration validation
+func TestS3Provider_Config(t *testing.T) {
+	tests := []struct {
+		name    string
+		config  map[string]interface{}
+		wantErr bool
+	}{
+		{
+			name: "valid config",
+			config: map[string]interface{}{
+				"bucket_name": "test-bucket",
+				"index_name":  "test-index",
+				"region":      "us-east-1",
+			},
+			wantErr: false,
+		},
+		{
+			name: "missing bucket_name",
+			config: map[string]interface{}{
+				"region": "us-east-1",
+			},
+			wantErr: true,
+		},
+		{
+			name: "default region",
+			config: map[string]interface{}{
+				"bucket_name": "test-bucket",
+			},
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			provider, err := NewS3Provider(tt.config)
+
+			if tt.wantErr {
+				if err == nil {
+					t.Errorf("NewS3Provider() expected error, got none")
+				}
+				return
+			}
+
+			if err != nil {
+				t.Errorf("NewS3Provider() error = %v", err)
+				return
+			}
+
+			if provider == nil {
+				t.Errorf("NewS3Provider() returned nil provider")
+			}
+		})
+	}
+}
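For reference, the provider above can also be exercised outside the test harness. The following is a minimal, hypothetical sketch (not part of this change) that reuses only the exported names introduced in this diff (NewS3Provider, Initialize) with placeholder bucket and index values:

package main

import (
	"context"
	"log"

	"github.com/tuannvm/slack-mcp-client/internal/rag"
)

func main() {
	// Same configuration keys the tests above use; the values here are placeholders.
	provider, err := rag.NewS3Provider(map[string]interface{}{
		"bucket_name": "example-vector-bucket", // required, per TestS3Provider_Config
		"index_name":  "example-index",
		"region":      "us-east-1", // defaulted by the provider when omitted
	})
	if err != nil {
		log.Fatalf("create provider: %v", err)
	}
	if err := provider.Initialize(context.Background()); err != nil {
		log.Fatalf("initialize provider: %v", err)
	}
	// Search still needs a QueryVector; see the Voyage client introduced next.
	log.Println("S3 vector provider ready")
}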
diff --git a/internal/rag/voyage_client.go b/internal/rag/voyage_client.go
new file mode 100644
index 0000000..6159c5b
--- /dev/null
+++ b/internal/rag/voyage_client.go
@@ -0,0 +1,143 @@
+package rag
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+)
+
+const (
+	voyageAPIURL = "https://api.voyageai.com/v1/contextualizedembeddings"
+	voyageModel  = "voyage-context-3"
+)
+
+// VoyageClient is a client for the Voyage AI contextualized embeddings API
+type VoyageClient struct {
+	apiKey     string
+	httpClient *http.Client
+}
+
+// NewVoyageClient creates a new Voyage AI client
+func NewVoyageClient(apiKey string) *VoyageClient {
+	return &VoyageClient{
+		apiKey: apiKey,
+		httpClient: &http.Client{
+			Timeout: 120 * time.Second,
+		},
+	}
+}
+
+// voyageContextualEmbedRequest represents the request payload for contextualized embeddings
+type voyageContextualEmbedRequest struct {
+	Inputs    [][]string `json:"inputs"`     // List of lists: outer=batch, inner=context
+	InputType string     `json:"input_type"` // "query" or "document"
+	Model     string     `json:"model"`
+}
+
+// voyageEmbeddingItem represents a single embedding result
+type voyageEmbeddingItem struct {
+	Object    string    `json:"object"`
+	Embedding []float32 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+// voyageDataItem represents a data item in the response
+type voyageDataItem struct {
+	Object string                `json:"object"`
+	Data   []voyageEmbeddingItem `json:"data"`
+	Index  int                   `json:"index"`
+}
+
+// voyageContextualEmbedResponse represents the response from Voyage API
+type voyageContextualEmbedResponse struct {
+	Object string           `json:"object"`
+	Data   []voyageDataItem `json:"data"`
+	Model  string           `json:"model"`
+	Usage  struct {
+		TotalTokens int `json:"total_tokens"`
+	} `json:"usage"`
+}
+
+// EmbeddingResult contains the embedding vector and token usage
+type EmbeddingResult struct {
+	Embedding  []float32
+	TokensUsed int
+	Model      string
+}
+
+// EmbedQuery embeds a query string using Voyage's contextualized embeddings API
+// Returns the embedding vector and token usage
+func (c *VoyageClient) EmbedQuery(ctx context.Context, query string) (*EmbeddingResult, error) {
+	// Prepare request payload
+	// inputs is a list of lists: [["query"]] for single query
+	reqBody := voyageContextualEmbedRequest{
+		Inputs:    [][]string{{query}},
+		InputType: "query",
+		Model:     voyageModel,
+	}
+
+	jsonData, err := json.Marshal(reqBody)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	// Create HTTP request
+	req, err := http.NewRequestWithContext(ctx, "POST", voyageAPIURL, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	// Set headers
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
+
+	// Execute request
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to execute request: %w", err)
+	}
+	defer func() {
+		_ = resp.Body.Close()
+	}()
+
+	// Read response body
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response body: %w", err)
+	}
+
+	// Check for HTTP errors
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("voyage API returned status %d: %s", resp.StatusCode, string(body))
+	}
+
+	// Parse response
+	var voyageResp voyageContextualEmbedResponse
+	if err := json.Unmarshal(body, &voyageResp); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+	}
+
+	// Extract embedding from response structure
+	// Response structure: data[0].data[0].embedding
+	if len(voyageResp.Data) == 0 {
+		return nil, fmt.Errorf("empty data array in response")
+	}
+	if len(voyageResp.Data[0].Data) == 0 {
+		return nil, fmt.Errorf("empty embedding data in response")
+	}
+
+	embedding := voyageResp.Data[0].Data[0].Embedding
+	if len(embedding) == 0 {
+		return nil, fmt.Errorf("empty embedding vector")
+	}
+
+	return &EmbeddingResult{
+		Embedding:  embedding,
+		TokensUsed: voyageResp.Usage.TotalTokens,
+		Model:      voyageResp.Model,
+	}, nil
+}
diff --git a/internal/rag/voyage_client_test.go b/internal/rag/voyage_client_test.go
new file mode 100644
index 0000000..8275cb2
--- /dev/null
+++ b/internal/rag/voyage_client_test.go
@@ -0,0 +1,80 @@
+package rag
+
+import (
+	"context"
+	"os"
+	"testing"
+)
+
+// TestVoyageClient_EmbedQuery tests the Voyage embedding client
+// This is an integration test that requires VOYAGE_API_KEY environment variable
+func TestVoyageClient_EmbedQuery(t *testing.T) {
+	apiKey := os.Getenv("VOYAGE_API_KEY")
+	if apiKey == "" {
+		t.Skip("VOYAGE_API_KEY not set, skipping integration test")
+	}
+
+	client := NewVoyageClient(apiKey)
+	ctx := context.Background()
+
+	tests := []struct {
+		name    string
+		query   string
+		wantErr bool
+	}{
+		{
+			name:    "simple query",
+			query:   "What is revenue?",
+			wantErr: false,
+		},
+		{
+			name:    "complex query",
+			query:   "What were the Q3 2025 revenues for the VX business unit in APAC region?",
+			wantErr: false,
+		},
+		{
+			name:    "empty query should fail",
+			query:   "",
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := client.EmbedQuery(ctx, tt.query)
+
+			if tt.wantErr {
+				if err == nil {
+					t.Errorf("EmbedQuery() expected error, got none")
+				}
+				return
+			}
+
+			if err != nil {
+				t.Errorf("EmbedQuery() error = %v", err)
+				return
+			}
+
+			// Verify embedding dimensions (voyage-context-3 should return 1024 dimensions)
+			if len(result.Embedding) == 0 {
+				t.Errorf("EmbedQuery() returned empty embedding")
+			}
+
+			t.Logf("Successfully generated embedding with %d dimensions (model: %s, tokens: %d)",
+				len(result.Embedding), result.Model, result.TokensUsed)
+		})
+	}
+}
+
+// TestVoyageClient_EmbedQuery_InvalidAPIKey tests error handling with invalid credentials
+func TestVoyageClient_EmbedQuery_InvalidAPIKey(t *testing.T) {
+	client := NewVoyageClient("invalid-api-key")
+	ctx := context.Background()
+
+	_, err := client.EmbedQuery(ctx, "test query")
+	if err == nil {
+		t.Errorf("EmbedQuery() with invalid API key should return error")
+	}
+
+	t.Logf("Correctly returned error: %v", err)
+}
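Before the Slack client wiring below, here is a rough, hypothetical end-to-end sketch (not part of this change) of how the two new rag pieces compose at query time: EmbedQuery produces the vector that Search consumes. It assumes only the exported API shown above, environment-supplied credentials, and hypothetical YYYYMMDD date values:

package main

import (
	"context"
	"log"
	"os"

	"github.com/tuannvm/slack-mcp-client/internal/rag"
)

func main() {
	ctx := context.Background()

	// 1. Embed the user query with the voyage-context-3 client added above.
	voyage := rag.NewVoyageClient(os.Getenv("VOYAGE_API_KEY"))
	emb, err := voyage.EmbedQuery(ctx, "revenue performance metrics")
	if err != nil {
		log.Fatalf("embed query: %v", err)
	}

	// 2. Feed the embedding into the S3 vector search, optionally date-filtered.
	provider, err := rag.NewS3Provider(map[string]interface{}{
		"bucket_name": os.Getenv("S3_VECTOR_BUCKET"),
		"index_name":  os.Getenv("S3_VECTOR_INDEX"),
		"region":      "us-east-1",
	})
	if err != nil {
		log.Fatalf("create provider: %v", err)
	}
	if err := provider.Initialize(ctx); err != nil {
		log.Fatalf("initialize provider: %v", err)
	}

	results, err := provider.Search(ctx, "revenue performance metrics", rag.SearchOptions{
		QueryVector: emb.Embedding,
		DateFilter:  []int{20251031, 20251030}, // hypothetical dates
		Limit:       5,
	})
	if err != nil {
		log.Fatalf("search: %v", err)
	}
	for _, r := range results {
		log.Printf("file=%s score=%.4f", r.FileID, r.Score)
	}
}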
diff --git a/internal/slack/client.go b/internal/slack/client.go
index c220606..b9207e9 100644
--- a/internal/slack/client.go
+++ b/internal/slack/client.go
@@ -27,16 +27,18 @@ import (
 // Client represents the Slack client application.
 type Client struct {
-	logger          *logging.Logger // Structured logger
-	userFrontend    UserFrontend
-	mcpClients      map[string]*mcp.Client
-	llmMCPBridge    *handlers.LLMMCPBridge
-	llmRegistry     *llm.ProviderRegistry // LLM provider registry
-	cfg             *config.Config // Holds the application configuration
-	messageHistory  map[string][]Message
-	historyLimit    int
-	discoveredTools map[string]mcp.ToolInfo
-	tracingHandler  observability.TracingHandler
+	logger                 *logging.Logger // Structured logger
+	userFrontend           UserFrontend
+	mcpClients             map[string]*mcp.Client
+	llmMCPBridge           *handlers.LLMMCPBridge
+	llmRegistry            *llm.ProviderRegistry // LLM provider registry
+	cfg                    *config.Config // Holds the application configuration
+	messageHistory         map[string][]Message
+	historyLimit           int
+	discoveredTools        map[string]mcp.ToolInfo
+	tracingHandler         observability.TracingHandler
+	queryEnhancer          *rag.QueryEnhancer // Query enhancer for all queries (not just RAG)
+	queryEnhancementPrompt string             // Query enhancement prompt template loaded from file
 }
 
 // Message represents a message in the conversation history
@@ -134,6 +136,22 @@ func NewClient(userFrontend UserFrontend, stdLogger *logging.Logger, mcpClients
 		if openaiConfig, exists := cfg.LLM.Providers["openai"]; exists && openaiConfig.APIKey != "" {
 			ragConfig["api_key"] = openaiConfig.APIKey
 		}
+	case "s3":
+		if providerSettings.BucketName != "" {
+			ragConfig["bucket_name"] = providerSettings.BucketName
+		}
+		if providerSettings.IndexName != "" {
+			ragConfig["index_name"] = providerSettings.IndexName
+		}
+		if providerSettings.Region != "" {
+			ragConfig["region"] = providerSettings.Region
+		}
+		if providerSettings.MaxResults > 0 {
+			ragConfig["max_results"] = providerSettings.MaxResults
+		}
+		if providerSettings.ScoreThreshold > 0 {
+			ragConfig["score_threshold"] = providerSettings.ScoreThreshold
+		}
 	}
 }
@@ -142,6 +160,16 @@
 	ragConfig["chunk_size"] = cfg.RAG.ChunkSize
 	}
 
+	// Set date filter configuration (applies to all providers)
+	if providerSettings, exists := cfg.RAG.Providers[cfg.RAG.Provider]; exists {
+		if providerSettings.DateFilterField != "" {
+			ragConfig["date_filter_field"] = providerSettings.DateFilterField
+		}
+		if providerSettings.DateRangeWindowDays > 0 {
+			ragConfig["date_range_window_days"] = providerSettings.DateRangeWindowDays
+		}
+	}
+
 	ragClient, err := rag.NewClientWithProvider(cfg.RAG.Provider, ragConfig)
 	if err != nil {
 		clientLogger.ErrorKV("Failed to create RAG client", "error", err)
@@ -164,6 +192,81 @@
 	}
 	clientLogger.Info("LLM provider registry initialized successfully")
 
+	// Initialize query enhancer for ALL queries (not just RAG)
+	var queryEnhancer *rag.QueryEnhancer
+	var queryEnhancementPrompt string
+
+	if cfg.QueryEnhancementProvider != "" {
+		// Validate: If queryEnhancementProvider is set, queryEnhancementPromptFile is required
+		if cfg.QueryEnhancementPromptFile == "" {
+			clientLogger.ErrorKV("queryEnhancementProvider is configured but queryEnhancementPromptFile is missing",
+				"provider", cfg.QueryEnhancementProvider)
+			return nil, customErrors.WrapConfigError(
+				fmt.Errorf("queryEnhancementPromptFile is required when queryEnhancementProvider is configured"),
+				"query_enhancement_config_invalid",
+				"queryEnhancementPromptFile must be set when queryEnhancementProvider is configured")
+		}
+
+		// Load query enhancement prompt from file
+		content, err := os.ReadFile(cfg.QueryEnhancementPromptFile)
+		if err != nil {
+			clientLogger.ErrorKV("Failed to read query enhancement prompt file",
+				"file", cfg.QueryEnhancementPromptFile, "error", err)
+			return nil, customErrors.WrapConfigError(err, "query_enhancement_prompt_file_read_failed",
+				"Failed to read query enhancement prompt file")
+		}
+		queryEnhancementPrompt = string(content)
+		clientLogger.InfoKV("Loaded query enhancement prompt from file",
+			"file", cfg.QueryEnhancementPromptFile, "length", len(queryEnhancementPrompt))
+
+		// Create a separate LLM registry for query enhancement
+		clientLogger.InfoKV("Creating dedicated LLM registry for query enhancement", "provider", cfg.QueryEnhancementProvider)
+
+		// Create a minimal config with only the specified provider
+		queryEnhancerConfig := &config.Config{
+			LLM: config.LLMConfig{
+				Provider:  cfg.QueryEnhancementProvider,
+				Providers: cfg.LLM.Providers, // Reuse all provider configs
+			},
+		}
+
+		queryEnhancerLogger := logging.New("query-enhancer", logLevel)
+		qeRegistry, err := llm.NewProviderRegistry(queryEnhancerConfig, queryEnhancerLogger)
+		if err != nil {
+			clientLogger.ErrorKV("Failed to create query enhancement LLM registry", "error", err)
+			return nil, customErrors.WrapLLMError(err, "query_enhancement_llm_registry_failed",
+				"Failed to create LLM registry for query enhancement")
+		}
+		queryEnhancer = rag.NewQueryEnhancer(qeRegistry)
+		clientLogger.InfoKV("Created query enhancer for all queries", "provider", cfg.QueryEnhancementProvider)
+	}
+
+	// Wire up RAG embedding provider
+	// Note: Query enhancement is now done in Slack client before LLM call
+	if ragClient, ok := rawClientMap["rag"].(*rag.Client); ok {
+		// Create embedding provider if configured
+		if cfg.RAG.EmbeddingProvider != "" {
+			// Get embedding provider config
+			var embeddingConfig rag.EmbeddingProviderConfig
+			if providerCfg, exists := cfg.RAG.EmbeddingProviders[cfg.RAG.EmbeddingProvider]; exists {
+				embeddingConfig = rag.EmbeddingProviderConfig{
+					APIKey: providerCfg.APIKey,
+				}
+			} else {
+				clientLogger.WarnKV("Embedding provider config not found in RAG.EmbeddingProviders",
+					"provider", cfg.RAG.EmbeddingProvider)
+			}
+
+			provider, err := rag.CreateEmbeddingProvider(cfg.RAG.EmbeddingProvider, embeddingConfig)
+			if err != nil {
+				clientLogger.ErrorKV("Failed to create embedding provider, continuing without embeddings", "provider", cfg.RAG.EmbeddingProvider, "error", err)
+			} else {
+				ragClient.SetEmbeddingProvider(provider)
+				clientLogger.InfoKV("Enabled RAG embeddings", "provider", cfg.RAG.EmbeddingProvider)
+			}
+		}
+	}
+
 	// Load custom prompt from file if specified and customPrompt is empty
 	if cfg.LLM.CustomPromptFile != "" && cfg.LLM.CustomPrompt == "" {
 		content, err := os.ReadFile(cfg.LLM.CustomPromptFile)
@@ -189,18 +292,26 @@
 	// Initialize observability
 	tracingHandler := observability.NewTracingHandler(cfg, clientLogger)
 
+	// Wire up tracing handler to RAG client for embedding span tracking
+	if ragClient, ok := rawClientMap["rag"].(*rag.Client); ok {
+		ragClient.SetTracingHandler(tracingHandler)
+		clientLogger.DebugKV("Set tracing handler on RAG client", "client", "rag")
+	}
+
 	// --- Create and return Client instance ---
 	return &Client{
-		logger:          clientLogger,
-		userFrontend:    userFrontend,
-		mcpClients:      mcpClients,
-		llmMCPBridge:    llmMCPBridge,
-		llmRegistry:     registry,
-		cfg:             cfg,
-		messageHistory:  make(map[string][]Message),
-		historyLimit:    cfg.Slack.MessageHistory, // Store configured number of messages per channel
-		discoveredTools: discoveredTools,
-		tracingHandler:  tracingHandler,
+		logger:                 clientLogger,
+		userFrontend:           userFrontend,
+		mcpClients:             mcpClients,
+		llmMCPBridge:           llmMCPBridge,
+		llmRegistry:            registry,
+		cfg:                    cfg,
+		messageHistory:         make(map[string][]Message),
+		historyLimit:           cfg.Slack.MessageHistory, // Store configured number of messages per channel
+		discoveredTools:        discoveredTools,
+		tracingHandler:         tracingHandler,
+		queryEnhancer:          queryEnhancer,          // Query enhancer for all queries
+		queryEnhancementPrompt: queryEnhancementPrompt, // Query enhancement prompt template
 	}, nil
 }
@@ -418,16 +529,73 @@ func (c *Client) handleUserPrompt(userPrompt, channelID, threadTS string, timest
 	// Show a temporary "typing" indicator
 	c.userFrontend.SendMessage(channelID, threadTS, c.cfg.Slack.ThinkingMessage)
 
+	var enhancedQuery string
+	var queryMetadata *rag.MetadataFilters
+
+	if c.queryEnhancer != nil {
+		today := time.Now().Format("2006-01-02") // Format as YYYY-MM-DD
+		fmt.Printf("[Query Enhancement] INPUT: '%s'\n", userPrompt)
+		fmt.Printf("[Query Enhancement] Today's date: %s\n", today)
+
+		// Start query enhancement span
+		qeCtx, qeSpan := c.tracingHandler.StartLLMSpan(ctx, "query-enhancement",
+			c.cfg.LLM.Providers[c.cfg.QueryEnhancementProvider].Model,
+			userPrompt,
+			map[string]interface{}{
+				"today":            today,
+				"enhancement_type": "temporal_detection",
+			})
+
+		startTime := time.Now()
+		enhanced, err := c.queryEnhancer.EnhanceQuery(qeCtx, userPrompt, today, c.queryEnhancementPrompt)
+		duration := time.Since(startTime)
+
+		c.tracingHandler.SetDuration(qeSpan, duration)
+
+		if err != nil {
+			fmt.Printf("[Query Enhancement] ERROR: %v, using original query\n", err)
+			c.logger.WarnKV("Query enhancement failed, using original query", "error", err)
+			c.tracingHandler.RecordError(qeSpan, err, "ERROR")
+			enhancedQuery = userPrompt
+		} else {
+			enhancedQuery = enhanced.EnhancedQuery
+			queryMetadata = &enhanced.MetadataFilters
+
+			fmt.Printf("[Query Enhancement] OUTPUT: '%s'\n", enhancedQuery)
+			if queryMetadata != nil && len(queryMetadata.Dates) > 0 {
+				fmt.Printf("[Query Enhancement] Detected temporal query with %d dates: %v\n", len(queryMetadata.Dates), queryMetadata.Dates)
+			} else {
+				fmt.Printf("[Query Enhancement] Non-temporal query (no date metadata)\n")
+			}
+
+			// Set output and metadata for tracing
+			c.tracingHandler.SetOutput(qeSpan, enhancedQuery)
+			if queryMetadata != nil && len(queryMetadata.Dates) > 0 {
+				c.tracingHandler.RecordSuccess(qeSpan, fmt.Sprintf("Temporal query detected: %d dates", len(queryMetadata.Dates)))
+			} else {
+				c.tracingHandler.RecordSuccess(qeSpan, "Non-temporal query")
+			}
+
+			c.logger.DebugKV("Query enhanced successfully", "enhanced", enhancedQuery, "has_metadata", queryMetadata != nil)
+		}
+		qeSpan.End()
+	} else {
+		fmt.Printf("[Query Enhancement] DISABLED: Using original query\n")
+		// No query enhancer configured, use original query
+		enhancedQuery = userPrompt
+	}
+
 	if !c.cfg.LLM.UseAgent {
 		// Prepare the final prompt with custom prompt as system instruction
+		// Use ENHANCED query instead of original userPrompt
 		var finalPrompt string
 		customPrompt := c.cfg.LLM.CustomPrompt
 		if customPrompt != "" {
-			// Use custom prompt as system instruction, then add user prompt
-			finalPrompt = fmt.Sprintf("System instructions: %s\n\nUser: %s", customPrompt, userPrompt)
+			// Use custom prompt as system instruction, then add enhanced query
+			finalPrompt = fmt.Sprintf("System instructions: %s\n\nUser: %s", customPrompt, enhancedQuery)
 			c.logger.DebugKV("Using custom prompt as system instruction", "custom_prompt_length", len(customPrompt))
 		} else {
-			finalPrompt = userPrompt
+			finalPrompt = enhancedQuery
 		}
 
 		llmCtx, llmSpan := c.tracingHandler.StartLLMSpan(ctx, "llm-call", c.cfg.LLM.Providers[c.cfg.LLM.Provider].Model, finalPrompt, map[string]interface{}{
@@ -473,7 +641,9 @@ func (c *Client) handleUserPrompt(userPrompt, channelID, threadTS string, timest
 		llmSpan.End()
 
 		// Process the LLM response through the MCP pipeline
-		c.processLLMResponseAndReply(llmCtx, llmResponse, userPrompt, channelID, threadTS)
+		// Pass enhancedQuery instead of userPrompt so re-prompt uses enhanced query
+		// Pass queryMetadata so it can be forwarded to RAG search
+		c.processLLMResponseAndReply(llmCtx, llmResponse, enhancedQuery, queryMetadata, channelID, threadTS)
 	} else {
 		// Agent path with enhanced tracing
 		agentCtx, agentSpan := c.tracingHandler.StartSpan(ctx, "llm-agent-call", "generation", userPrompt, map[string]string{
@@ -616,7 +786,7 @@ func (c *Client) estimateToolTokenUsage(toolName, prompt, response string) int {
 
 // processLLMResponseAndReply processes the LLM response, handles tool results with re-prompting, and sends the final reply.
 // Incorporates logic previously in LLMClient.ProcessToolResponse.
-func (c *Client) processLLMResponseAndReply(traceCtx context.Context, llmResponse *llms.ContentChoice, userPrompt, channelID, threadTS string) {
+func (c *Client) processLLMResponseAndReply(traceCtx context.Context, llmResponse *llms.ContentChoice, userPrompt string, queryMetadata *rag.MetadataFilters, channelID, threadTS string) {
 	// Start tool processing span
 	ctx, span := c.tracingHandler.StartSpan(traceCtx, "tool-processing", "span", userPrompt, map[string]string{
 		"channel_id": channelID,
@@ -631,11 +801,13 @@
 		"channel_id": channelID,
 		"thread_ts":  threadTS,
 	}
-	c.logger.DebugKV("Added extra arguments", "channel_id", channelID, "thread_ts", threadTS)
 
-	// Create a context with timeout for tool processing
-	toolCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
-	defer cancel()
+	if queryMetadata != nil {
+		extraArgs["query_metadata"] = queryMetadata
+		c.logger.DebugKV("Added query metadata to extra arguments", "date_count", len(queryMetadata.Dates))
+	}
+
+	c.logger.DebugKV("Added extra arguments", "channel_id", channelID, "thread_ts", threadTS)
 
 	// --- Process Tool Response (Logic from LLMClient.ProcessToolResponse) ---
 	var finalResponse string
@@ -649,41 +821,56 @@
 		toolProcessingErr = nil
 		c.logger.Warn("LLMMCPBridge is nil, skipping tool processing")
 	} else {
-		// Extract tool name before execution
-		executedToolName := c.extractToolNameFromResponse(llmResponse.Content)
-
-		// Start tool execution span
-		_, toolExecSpan := c.tracingHandler.StartSpan(ctx, "tool-execution", "event", "", map[string]string{
-			"bridge_available": "true",
-			"response_type":    "processing",
-			"tool_name":        executedToolName,
-		})
-		startTime := time.Now()
-		// Process the response through the bridge
-		processedResponse, err := c.llmMCPBridge.ProcessLLMResponse(toolCtx, llmResponse, userPrompt, extraArgs)
-		toolDuration := time.Since(startTime)
-		c.tracingHandler.SetDuration(toolExecSpan, toolDuration)
+		// Extract tool call from LLM response using bridge's logic (single source of truth)
+		toolCall, err := c.llmMCPBridge.ExtractToolCall(llmResponse)
 		if err != nil {
-			finalResponse = fmt.Sprintf("Sorry, I encountered an error while trying to use a tool: %v", err)
+			finalResponse = fmt.Sprintf("Sorry, I encountered an error extracting tool call: %v", err)
 			isToolResult = false
-			toolProcessingErr = err // Store the error
-			c.tracingHandler.RecordError(toolExecSpan, err, "ERROR")
-		} else {
-			// If the processed response is different from the original, a tool was executed
-			if processedResponse != llmResponse.Content {
+			toolProcessingErr = err
+			c.logger.ErrorKV("Failed to extract tool call", "error", err)
+		} else if toolCall != nil {
+			// Tool call detected - execute it with tracing
+			c.logger.InfoKV("Tool call detected", "tool", toolCall.Tool)
+
+			// Marshal args for tracing
+			argsJSON, _ := json.Marshal(toolCall.Args)
+
+			// Start tool execution span with tool arguments as input
+			// IMPORTANT: Use the returned context so child spans (embedding, retriever) are properly nested
+			toolExecCtx, toolExecSpan := c.tracingHandler.StartSpan(ctx, "tool-execution", "tool", string(argsJSON), map[string]string{
+				"bridge_available": "true",
+				"response_type":    "processing",
+				"tool_name":        toolCall.Tool,
+			})
+
+			// Create a context with timeout from the span context
+			toolCtx, cancel := context.WithTimeout(toolExecCtx, 1*time.Minute)
+			defer cancel()
+
+			startTime := time.Now()
+			// Execute the tool call
+			processedResponse, err := c.llmMCPBridge.ExecuteToolCall(toolCtx, toolCall, extraArgs)
+			toolDuration := time.Since(startTime)
+			c.tracingHandler.SetDuration(toolExecSpan, toolDuration)
+
+			if err != nil {
+				finalResponse = fmt.Sprintf("Sorry, I encountered an error while trying to use a tool: %v", err)
+				isToolResult = false
+				toolProcessingErr = err
+				c.tracingHandler.RecordError(toolExecSpan, err, "ERROR")
+			} else {
 				finalResponse = processedResponse
 				isToolResult = true
 				c.tracingHandler.SetOutput(toolExecSpan, processedResponse)
 				c.tracingHandler.RecordSuccess(toolExecSpan, "Tool executed successfully")
-			} else {
-				// No tool was executed
-				finalResponse = llmResponse.Content
-				isToolResult = false
-				c.tracingHandler.SetOutput(toolExecSpan, "No tool execution required")
-				c.tracingHandler.RecordSuccess(toolExecSpan, "No tool processing needed")
 			}
+			toolExecSpan.End()
+		} else {
+			// No tool call detected - just use LLM response content
+			finalResponse = llmResponse.Content
+			isToolResult = false
+			c.logger.Debug("No tool call detected, using LLM response as-is")
 		}
-		toolExecSpan.End()
 	}
 	// --- End of Process Tool Response Logic ---
@@ -699,7 +886,7 @@
 	c.logger.DebugKV("Tool result", "result", logging.TruncateForLog(finalResponse, 500))
 
 	// Always re-prompt LLM with tool results for synthesis
-	// Construct a new prompt incorporating the original prompt and the tool result
+	// Construct a new prompt incorporating the enhanced query and the tool result
 	rePrompt := fmt.Sprintf("The user asked: '%s'\n\nI searched the knowledge base and found the following relevant information:\n```\n%s\n```\n\nPlease analyze and synthesize this retrieved information to provide a comprehensive response to the user's request. Use the detailed information from the search results according to your system instructions.", userPrompt, finalResponse)
 
 	// Start re-prompt span