diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go index 9be025c29..ba73af4dd 100644 --- a/internal/auth/kiro/aws.go +++ b/internal/auth/kiro/aws.go @@ -40,6 +40,10 @@ type KiroTokenData struct { ClientSecret string `json:"clientSecret,omitempty"` // Email is the user's email address (used for file naming) Email string `json:"email,omitempty"` + // StartURL is the IDC/Identity Center start URL (only for IDC auth method) + StartURL string `json:"startUrl,omitempty"` + // Region is the AWS region for IDC authentication (only for IDC auth method) + Region string `json:"region,omitempty"` } // KiroAuthBundle aggregates authentication data after OAuth flow completion diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 2c9150f1b..ab44e55f6 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -2,16 +2,19 @@ package kiro import ( + "bufio" "context" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "html" "io" "net" "net/http" + "os" "strings" "time" @@ -24,19 +27,31 @@ import ( const ( // AWS SSO OIDC endpoints ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" - + // Kiro's start URL for Builder ID builderIDStartURL = "https://view.awsapps.com/start" - + + // Default region for IDC + defaultIDCRegion = "us-east-1" + // Polling interval pollInterval = 5 * time.Second - + // Authorization code flow callback authCodeCallbackPath = "/oauth/callback" authCodeCallbackPort = 19877 - + // User-Agent to match official Kiro IDE kiroUserAgent = "KiroIDE" + + // IDC token refresh headers (matching Kiro IDE behavior) + idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE" +) + +// Sentinel errors for OIDC token polling +var ( + ErrAuthorizationPending = errors.New("authorization_pending") + ErrSlowDown = errors.New("slow_down") ) // SSOOIDCClient handles AWS SSO OIDC authentication. @@ -83,6 +98,440 @@ type CreateTokenResponse struct { RefreshToken string `json:"refreshToken"` } +// getOIDCEndpoint returns the OIDC endpoint for the given region. +func getOIDCEndpoint(region string) string { + if region == "" { + region = defaultIDCRegion + } + return fmt.Sprintf("https://oidc.%s.amazonaws.com", region) +} + +// promptInput prompts the user for input with an optional default value. +func promptInput(prompt, defaultValue string) string { + reader := bufio.NewReader(os.Stdin) + if defaultValue != "" { + fmt.Printf("%s [%s]: ", prompt, defaultValue) + } else { + fmt.Printf("%s: ", prompt) + } + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return defaultValue + } + input = strings.TrimSpace(input) + if input == "" { + return defaultValue + } + return input +} + +// promptSelect prompts the user to select from options using number input. +func promptSelect(prompt string, options []string) int { + reader := bufio.NewReader(os.Stdin) + + for { + fmt.Println(prompt) + for i, opt := range options { + fmt.Printf(" %d) %s\n", i+1, opt) + } + fmt.Printf("Enter selection (1-%d): ", len(options)) + + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return 0 // Default to first option on error + } + input = strings.TrimSpace(input) + + // Parse the selection + var selection int + if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { + fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options)) + continue + } + return selection - 1 + } +} + +// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. +func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC. +func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": startURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateTokenWithRegion polls for the access token after user authorization using a specific region. +func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, ErrAuthorizationPending + } + if errResp.Error == "slow_down" { + return nil, ErrSlowDown + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region. +func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + + // Set headers matching kiro2api's IDC token refresh + // These headers are required for successful IDC token refresh + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) + req.Header.Set("Connection", "keep-alive") + req.Header.Set("x-amz-user-agent", idcAmzUserAgent) + req.Header.Set("Accept", "*/*") + req.Header.Set("Accept-Language", "*") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("User-Agent", "node") + req.Header.Set("Accept-Encoding", "br, gzip, deflate") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + StartURL: startURL, + Region: region, + }, nil +} + +// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC). +func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Identity Center) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Register client with the specified region + fmt.Println("\nRegistering client...") + regResp, err := c.RegisterClientWithRegion(ctx, region) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 2: Start device authorization with IDC start URL + fmt.Println("Starting device authorization...") + authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region) + if err != nil { + return nil, fmt.Errorf("failed to start device auth: %w", err) + } + + // Step 3: Show user the verification URL + fmt.Printf("\n") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf(" Confirm the following code in the browser:\n") + fmt.Printf(" Code: %s\n", authResp.UserCode) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete) + + // Set incognito mode based on config + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Open browser + if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" Please open the URL manually in your browser.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + // Step 4: Poll for token + fmt.Println("Waiting for authorization...") + + interval := pollInterval + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + browser.CloseBrowser() + return nil, ctx.Err() + case <-time.After(interval): + tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region) + if err != nil { + if errors.Is(err, ErrAuthorizationPending) { + fmt.Print(".") + continue + } + if errors.Is(err, ErrSlowDown) { + interval += 5 * time.Second + continue + } + browser.CloseBrowser() + return nil, fmt.Errorf("token creation failed: %w", err) + } + + fmt.Println("\n\n✓ Authorization successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Step 5: Get profile ARN from CodeWhisperer API + fmt.Println("Fetching profile information...") + profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + + // Fetch user email + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + StartURL: startURL, + Region: region, + }, nil + } + } + + // Close browser on timeout + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") +} + +// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login. +func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Prompt for login method + options := []string{ + "Use with Builder ID (personal AWS account)", + "Use with IDC Account (organization SSO)", + } + selection := promptSelect("\n? Select login method:", options) + + if selection == 0 { + // Builder ID flow - use existing implementation + return c.LoginWithBuilderID(ctx) + } + + // IDC flow - prompt for start URL and region + fmt.Println() + startURL := promptInput("? Enter Start URL", "") + if startURL == "" { + return nil, fmt.Errorf("start URL is required for IDC login") + } + + region := promptInput("? Enter Region", defaultIDCRegion) + + return c.LoginWithIDC(ctx, startURL, region) +} + // RegisterClient registers a new OIDC client with AWS. func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { payload := map[string]interface{}{ @@ -211,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, } if json.Unmarshal(respBody, &errResp) == nil { if errResp.Error == "authorization_pending" { - return nil, fmt.Errorf("authorization_pending") + return nil, ErrAuthorizationPending } if errResp.Error == "slow_down" { - return nil, fmt.Errorf("slow_down") + return nil, ErrSlowDown } } log.Debugf("create token failed: %s", string(respBody)) @@ -359,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, case <-time.After(interval): tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) if err != nil { - errStr := err.Error() - if strings.Contains(errStr, "authorization_pending") { + if errors.Is(err, ErrAuthorizationPending) { fmt.Print(".") continue } - if strings.Contains(errStr, "slow_down") { + if errors.Is(err, ErrSlowDown) { interval += 5 * time.Second continue } diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 1da7f25ba..1e882888d 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -10,6 +10,8 @@ import ( "fmt" "io" "net/http" + "os" + "path/filepath" "strings" "sync" "time" @@ -43,10 +45,15 @@ const ( // Event Stream error type constants ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed - // kiroUserAgent matches amq2api format for User-Agent header + // kiroUserAgent matches amq2api format for User-Agent header (Amazon Q CLI style) kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" - // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api + // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api (Amazon Q CLI style) kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI" + + // Kiro IDE style headers (from kiro2api - for IDC auth) + kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1" + kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1" + kiroIDEAgentModeSpec = "spec" ) // Real-time usage estimation configuration @@ -101,11 +108,24 @@ var kiroEndpointConfigs = []kiroEndpointConfig{ // getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order. // Supports reordering based on "preferred_endpoint" in auth metadata/attributes. +// For IDC auth method, automatically uses CodeWhisperer endpoint with CLI origin. func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { if auth == nil { return kiroEndpointConfigs } + // For IDC auth, use CodeWhisperer endpoint with AI_EDITOR origin (same as Social auth) + // Based on kiro2api analysis: IDC tokens work with CodeWhisperer endpoint using Bearer auth + // The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC) + // NOT in how API calls are made - both Social and IDC use the same endpoint/origin + if auth.Metadata != nil { + authMethod, _ := auth.Metadata["auth_method"].(string) + if authMethod == "idc" { + log.Debugf("kiro: IDC auth, using CodeWhisperer endpoint") + return kiroEndpointConfigs + } + } + // Check for preference var preference string if auth.Metadata != nil { @@ -162,6 +182,15 @@ type KiroExecutor struct { refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions } +// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. +func isIDCAuth(auth *cliproxyauth.Auth) bool { + if auth == nil || auth.Metadata == nil { + return false + } + authMethod, _ := auth.Metadata["auth_method"].(string) + return authMethod == "idc" +} + // buildKiroPayloadForFormat builds the Kiro API payload based on the source format. // This is critical because OpenAI and Claude formats have different tool structures: // - OpenAI: tools[].function.name, tools[].function.description @@ -210,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before request") } @@ -262,15 +295,28 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + + // Use different headers based on auth type + // IDC auth uses Kiro IDE style headers (from kiro2api) + // Other auth types use Amazon Q CLI style headers + if isIDCAuth(auth) { + httpReq.Header.Set("User-Agent", kiroIDEUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) + httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) + log.Debugf("kiro: using Kiro IDE headers for IDC auth") + } else { + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + } httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + var attrs map[string]string if auth != nil { attrs = auth.Attributes @@ -358,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -416,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying request") @@ -518,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before stream request") } @@ -568,15 +628,28 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + + // Use different headers based on auth type + // IDC auth uses Kiro IDE style headers (from kiro2api) + // Other auth types use Amazon Q CLI style headers + if isIDCAuth(auth) { + httpReq.Header.Set("User-Agent", kiroIDEUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) + httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) + log.Debugf("kiro: using Kiro IDE headers for IDC auth") + } else { + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + } httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + var attrs map[string]string if auth != nil { attrs = auth.Attributes @@ -677,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -735,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying stream request") @@ -1001,12 +1084,12 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { // This consolidates the auth_method check that was previously done separately. func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { if auth != nil && auth.Metadata != nil { - if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { - // builder-id auth doesn't need profileArn + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") { + // builder-id and idc auth don't need profileArn return "" } } - // For non-builder-id auth (social auth), profileArn is required + // For non-builder-id/idc auth (social auth), profileArn is required if profileArn == "" { log.Warnf("kiro: profile ARN not found in auth, API calls may fail") } @@ -3010,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var refreshToken string var clientID, clientSecret string var authMethod string + var region, startURL string if auth.Metadata != nil { if rt, ok := auth.Metadata["refresh_token"].(string); ok { @@ -3024,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c if am, ok := auth.Metadata["auth_method"].(string); ok { authMethod = am } + if r, ok := auth.Metadata["region"].(string); ok { + region = r + } + if su, ok := auth.Metadata["start_url"].(string); ok { + startURL = su + } } if refreshToken == "" { @@ -3033,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var tokenData *kiroauth.KiroTokenData var err error - // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint - if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + + // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + // IDC refresh with region-specific endpoint + log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region) + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + // Builder ID refresh with default endpoint log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") - ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) - } else { + default: + // Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub) log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") oauth := kiroauth.NewKiroOAuth(e.cfg) tokenData, err = oauth.RefreshToken(ctx, refreshToken) @@ -3094,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c return updated, nil } +// persistRefreshedAuth persists a refreshed auth record to disk. +// This ensures token refreshes from inline retry are saved to the auth file. +func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { + if auth == nil || auth.Metadata == nil { + return fmt.Errorf("kiro executor: cannot persist nil auth or metadata") + } + + // Determine the file path from auth attributes or filename + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return fmt.Errorf("kiro executor: auth has no file path or filename") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return fmt.Errorf("kiro executor: cannot determine auth file path") + } + } + + // Marshal metadata to JSON + raw, err := json.Marshal(auth.Metadata) + if err != nil { + return fmt.Errorf("kiro executor: marshal metadata failed: %w", err) + } + + // Write to temp file first, then rename (atomic write) + tmp := authPath + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return fmt.Errorf("kiro executor: write temp auth file failed: %w", err) + } + if err := os.Rename(tmp, authPath); err != nil { + return fmt.Errorf("kiro executor: rename auth file failed: %w", err) + } + + log.Debugf("kiro executor: persisted refreshed auth to %s", authPath) + return nil +} + // isTokenExpired checks if a JWT access token has expired. // Returns true if the token is expired or cannot be parsed. func (e *KiroExecutor) isTokenExpired(accessToken string) bool { diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go index b937152d8..b75cd28ef 100644 --- a/sdk/auth/kiro.go +++ b/sdk/auth/kiro.go @@ -53,20 +53,8 @@ func (a *KiroAuthenticator) RefreshLead() *time.Duration { return &d } -// Login performs OAuth login for Kiro with AWS Builder ID. -func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - oauth := kiroauth.NewKiroOAuth(cfg) - - // Use AWS Builder ID device code flow - tokenData, err := oauth.LoginWithBuilderID(ctx) - if err != nil { - return nil, fmt.Errorf("login failed: %w", err) - } - +// createAuthRecord creates an auth record from token data. +func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) { // Parse expires_at expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) if err != nil { @@ -76,34 +64,63 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts // Extract identifier for file naming idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + // Determine label based on auth method + label := fmt.Sprintf("kiro-%s", source) + if tokenData.AuthMethod == "idc" { + label = "kiro-idc" + } + now := time.Now() - fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + fileName := fmt.Sprintf("%s-%s.json", label, idPart) + + metadata := map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "email": tokenData.Email, + } + + // Add IDC-specific fields if present + if tokenData.StartURL != "" { + metadata["start_url"] = tokenData.StartURL + } + if tokenData.Region != "" { + metadata["region"] = tokenData.Region + } + + attributes := map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": source, + "email": tokenData.Email, + } + + // Add IDC-specific attributes if present + if tokenData.AuthMethod == "idc" { + attributes["source"] = "aws-idc" + if tokenData.StartURL != "" { + attributes["start_url"] = tokenData.StartURL + } + if tokenData.Region != "" { + attributes["region"] = tokenData.Region + } + } record := &coreauth.Auth{ ID: fileName, Provider: "kiro", FileName: fileName, - Label: "kiro-aws", + Label: label, Status: coreauth.StatusActive, CreatedAt: now, UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "client_id": tokenData.ClientID, - "client_secret": tokenData.ClientSecret, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "aws-builder-id", - "email": tokenData.Email, - }, + Metadata: metadata, + Attributes: attributes, // NextRefreshAfter is aligned with RefreshLead (5min) NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } @@ -117,6 +134,23 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts return record, nil } +// Login performs OAuth login for Kiro with AWS (Builder ID or IDC). +// This shows a method selection prompt and handles both flows. +func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + // Use the unified method selection flow (Builder ID or IDC) + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + tokenData, err := ssoClient.LoginWithMethodSelection(ctx) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + + return a.createAuthRecord(tokenData, "aws") +} + // LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow. // This provides a better UX than device code flow as it uses automatic browser callback. func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { @@ -388,15 +422,23 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut clientID, _ := auth.Metadata["client_id"].(string) clientSecret, _ := auth.Metadata["client_secret"].(string) authMethod, _ := auth.Metadata["auth_method"].(string) + startURL, _ := auth.Metadata["start_url"].(string) + region, _ := auth.Metadata["region"].(string) var tokenData *kiroauth.KiroTokenData var err error - // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint - if clientID != "" && clientSecret != "" && authMethod == "builder-id" { - ssoClient := kiroauth.NewSSOOIDCClient(cfg) + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + + // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + // IDC refresh with region-specific endpoint + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + // Builder ID refresh with default endpoint tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) - } else { + default: // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) oauth := kiroauth.NewKiroOAuth(cfg) tokenData, err = oauth.RefreshToken(ctx, refreshToken)