From 7fd98f3556ea4b48cf0ebd2e4497d12008c53f88 Mon Sep 17 00:00:00 2001 From: Joao Date: Sun, 21 Dec 2025 14:49:19 +0000 Subject: [PATCH 1/3] feat: add IDC auth support with Kiro IDE headers --- internal/auth/kiro/aws.go | 4 + internal/auth/kiro/cognito.go | 408 +++++++++++++++++++ internal/auth/kiro/sso_oidc.go | 432 ++++++++++++++++++++- internal/runtime/executor/kiro_executor.go | 203 +++++++++- sdk/auth/kiro.go | 116 ++++-- 5 files changed, 1113 insertions(+), 50 deletions(-) create mode 100644 internal/auth/kiro/cognito.go diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go index 9be025c29..ba73af4dd 100644 --- a/internal/auth/kiro/aws.go +++ b/internal/auth/kiro/aws.go @@ -40,6 +40,10 @@ type KiroTokenData struct { ClientSecret string `json:"clientSecret,omitempty"` // Email is the user's email address (used for file naming) Email string `json:"email,omitempty"` + // StartURL is the IDC/Identity Center start URL (only for IDC auth method) + StartURL string `json:"startUrl,omitempty"` + // Region is the AWS region for IDC authentication (only for IDC auth method) + Region string `json:"region,omitempty"` } // KiroAuthBundle aggregates authentication data after OAuth flow completion diff --git a/internal/auth/kiro/cognito.go b/internal/auth/kiro/cognito.go new file mode 100644 index 000000000..7cf328186 --- /dev/null +++ b/internal/auth/kiro/cognito.go @@ -0,0 +1,408 @@ +// Package kiro provides Cognito Identity credential exchange for IDC authentication. +// AWS Identity Center (IDC) requires SigV4 signing with Cognito-exchanged credentials +// instead of Bearer token authentication. +package kiro + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // Cognito Identity endpoints + cognitoIdentityEndpoint = "https://cognito-identity.us-east-1.amazonaws.com" + + // Identity Pool ID for Q Developer / CodeWhisperer + // This is the identity pool used by kiro-cli and Amazon Q CLI + cognitoIdentityPoolID = "us-east-1:70717e99-906f-485d-8d89-c89a0b5d49c5" + + // Cognito provider name for SSO OIDC + cognitoProviderName = "cognito-identity.amazonaws.com" +) + +// CognitoCredentials holds temporary AWS credentials from Cognito Identity. +type CognitoCredentials struct { + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + SessionToken string `json:"session_token"` + Expiration time.Time `json:"expiration"` +} + +// CognitoIdentityClient handles Cognito Identity credential exchange. +type CognitoIdentityClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewCognitoIdentityClient creates a new Cognito Identity client. +func NewCognitoIdentityClient(cfg *config.Config) *CognitoIdentityClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &CognitoIdentityClient{ + httpClient: client, + cfg: cfg, + } +} + +// GetIdentityID retrieves a Cognito Identity ID using the SSO access token. +func (c *CognitoIdentityClient) GetIdentityID(ctx context.Context, accessToken, region string) (string, error) { + if region == "" { + region = "us-east-1" + } + + endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) + + // Build the GetId request + // The SSO token is passed as a login token for the identity pool + payload := map[string]interface{}{ + "IdentityPoolId": cognitoIdentityPoolID, + "Logins": map[string]string{ + // Use the OIDC provider URL as the key + fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, + }, + } + + body, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal GetId request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) + if err != nil { + return "", fmt.Errorf("failed to create GetId request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.1") + req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetId") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("GetId request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read GetId response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("Cognito GetId failed (status %d): %s", resp.StatusCode, string(respBody)) + return "", fmt.Errorf("GetId failed (status %d): %s", resp.StatusCode, string(respBody)) + } + + var result struct { + IdentityID string `json:"IdentityId"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("failed to parse GetId response: %w", err) + } + + if result.IdentityID == "" { + return "", fmt.Errorf("empty IdentityId in GetId response") + } + + log.Debugf("Cognito Identity ID: %s", result.IdentityID) + return result.IdentityID, nil +} + +// GetCredentialsForIdentity exchanges an identity ID and login token for temporary AWS credentials. +func (c *CognitoIdentityClient) GetCredentialsForIdentity(ctx context.Context, identityID, accessToken, region string) (*CognitoCredentials, error) { + if region == "" { + region = "us-east-1" + } + + endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) + + payload := map[string]interface{}{ + "IdentityId": identityID, + "Logins": map[string]string{ + fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, + }, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal GetCredentialsForIdentity request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create GetCredentialsForIdentity request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.1") + req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetCredentialsForIdentity") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("GetCredentialsForIdentity request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read GetCredentialsForIdentity response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("Cognito GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Credentials struct { + AccessKeyID string `json:"AccessKeyId"` + SecretKey string `json:"SecretKey"` + SessionToken string `json:"SessionToken"` + Expiration int64 `json:"Expiration"` + } `json:"Credentials"` + IdentityID string `json:"IdentityId"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to parse GetCredentialsForIdentity response: %w", err) + } + + if result.Credentials.AccessKeyID == "" { + return nil, fmt.Errorf("empty AccessKeyId in GetCredentialsForIdentity response") + } + + // Expiration is in seconds since epoch + expiration := time.Unix(result.Credentials.Expiration, 0) + + log.Debugf("Cognito credentials obtained, expires: %s", expiration.Format(time.RFC3339)) + + return &CognitoCredentials{ + AccessKeyID: result.Credentials.AccessKeyID, + SecretAccessKey: result.Credentials.SecretKey, + SessionToken: result.Credentials.SessionToken, + Expiration: expiration, + }, nil +} + +// ExchangeSSOTokenForCredentials is a convenience method that performs the full +// Cognito Identity credential exchange flow: GetId -> GetCredentialsForIdentity +func (c *CognitoIdentityClient) ExchangeSSOTokenForCredentials(ctx context.Context, accessToken, region string) (*CognitoCredentials, error) { + log.Debugf("Exchanging SSO token for Cognito credentials (region: %s)", region) + + // Step 1: Get Identity ID + identityID, err := c.GetIdentityID(ctx, accessToken, region) + if err != nil { + return nil, fmt.Errorf("failed to get identity ID: %w", err) + } + + // Step 2: Get credentials for the identity + creds, err := c.GetCredentialsForIdentity(ctx, identityID, accessToken, region) + if err != nil { + return nil, fmt.Errorf("failed to get credentials for identity: %w", err) + } + + return creds, nil +} + +// SigV4Signer provides AWS Signature Version 4 signing for HTTP requests. +type SigV4Signer struct { + credentials *CognitoCredentials + region string + service string +} + +// NewSigV4Signer creates a new SigV4 signer with the given credentials. +func NewSigV4Signer(creds *CognitoCredentials, region, service string) *SigV4Signer { + return &SigV4Signer{ + credentials: creds, + region: region, + service: service, + } +} + +// SignRequest signs an HTTP request using AWS Signature Version 4. +// The request body must be provided separately since it may have been read already. +func (s *SigV4Signer) SignRequest(req *http.Request, body []byte) error { + now := time.Now().UTC() + amzDate := now.Format("20060102T150405Z") + dateStamp := now.Format("20060102") + + // Ensure required headers are set + if req.Header.Get("Host") == "" { + req.Header.Set("Host", req.URL.Host) + } + req.Header.Set("X-Amz-Date", amzDate) + if s.credentials.SessionToken != "" { + req.Header.Set("X-Amz-Security-Token", s.credentials.SessionToken) + } + + // Create canonical request + canonicalRequest, signedHeaders := s.createCanonicalRequest(req, body) + + // Create string to sign + algorithm := "AWS4-HMAC-SHA256" + credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, s.region, s.service) + stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", + algorithm, + amzDate, + credentialScope, + hashSHA256([]byte(canonicalRequest)), + ) + + // Calculate signature + signingKey := s.getSignatureKey(dateStamp) + signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) + + // Build Authorization header + authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", + algorithm, + s.credentials.AccessKeyID, + credentialScope, + signedHeaders, + signature, + ) + + req.Header.Set("Authorization", authHeader) + + return nil +} + +// createCanonicalRequest builds the canonical request string for SigV4. +func (s *SigV4Signer) createCanonicalRequest(req *http.Request, body []byte) (string, string) { + // HTTP method + method := req.Method + + // Canonical URI + uri := req.URL.Path + if uri == "" { + uri = "/" + } + + // Canonical query string (sorted) + queryString := s.buildCanonicalQueryString(req) + + // Canonical headers (sorted, lowercase) + canonicalHeaders, signedHeaders := s.buildCanonicalHeaders(req) + + // Hashed payload + payloadHash := hashSHA256(body) + + canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", + method, + uri, + queryString, + canonicalHeaders, + signedHeaders, + payloadHash, + ) + + return canonicalRequest, signedHeaders +} + +// buildCanonicalQueryString builds a sorted, URI-encoded query string. +func (s *SigV4Signer) buildCanonicalQueryString(req *http.Request) string { + if req.URL.RawQuery == "" { + return "" + } + + // Parse and sort query parameters + params := make([]string, 0) + for key, values := range req.URL.Query() { + for _, value := range values { + params = append(params, fmt.Sprintf("%s=%s", uriEncode(key), uriEncode(value))) + } + } + sort.Strings(params) + return strings.Join(params, "&") +} + +// buildCanonicalHeaders builds sorted, lowercase canonical headers. +func (s *SigV4Signer) buildCanonicalHeaders(req *http.Request) (string, string) { + // Headers to sign (must include host and x-amz-*) + headerMap := make(map[string]string) + headerMap["host"] = req.URL.Host + + for key, values := range req.Header { + lowKey := strings.ToLower(key) + // Include x-amz-* headers and content-type + if strings.HasPrefix(lowKey, "x-amz-") || lowKey == "content-type" { + headerMap[lowKey] = strings.TrimSpace(values[0]) + } + } + + // Sort header names + headerNames := make([]string, 0, len(headerMap)) + for name := range headerMap { + headerNames = append(headerNames, name) + } + sort.Strings(headerNames) + + // Build canonical headers and signed headers + var canonicalHeaders strings.Builder + for _, name := range headerNames { + canonicalHeaders.WriteString(name) + canonicalHeaders.WriteString(":") + canonicalHeaders.WriteString(headerMap[name]) + canonicalHeaders.WriteString("\n") + } + + signedHeaders := strings.Join(headerNames, ";") + + return canonicalHeaders.String(), signedHeaders +} + +// getSignatureKey derives the signing key for SigV4. +func (s *SigV4Signer) getSignatureKey(dateStamp string) []byte { + kDate := hmacSHA256([]byte("AWS4"+s.credentials.SecretAccessKey), []byte(dateStamp)) + kRegion := hmacSHA256(kDate, []byte(s.region)) + kService := hmacSHA256(kRegion, []byte(s.service)) + kSigning := hmacSHA256(kService, []byte("aws4_request")) + return kSigning +} + +// hmacSHA256 computes HMAC-SHA256. +func hmacSHA256(key, data []byte) []byte { + h := hmac.New(sha256.New, key) + h.Write(data) + return h.Sum(nil) +} + +// hashSHA256 computes SHA256 hash and returns hex string. +func hashSHA256(data []byte) string { + hash := sha256.Sum256(data) + return hex.EncodeToString(hash[:]) +} + +// uriEncode performs URI encoding for SigV4. +func uriEncode(s string) string { + var result strings.Builder + for i := 0; i < len(s); i++ { + c := s[i] + if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || c == '-' || c == '.' || c == '_' || c == '~' { + result.WriteByte(c) + } else { + result.WriteString(fmt.Sprintf("%%%02X", c)) + } + } + return result.String() +} + +// IsExpired checks if the credentials are expired or about to expire. +func (c *CognitoCredentials) IsExpired() bool { + // Consider expired if within 5 minutes of expiration + return time.Now().Add(5 * time.Minute).After(c.Expiration) +} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 2c9150f1b..6ef2e9605 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -2,6 +2,7 @@ package kiro import ( + "bufio" "context" "crypto/rand" "crypto/sha256" @@ -12,6 +13,7 @@ import ( "io" "net" "net/http" + "os" "strings" "time" @@ -24,10 +26,13 @@ import ( const ( // AWS SSO OIDC endpoints ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" - + // Kiro's start URL for Builder ID builderIDStartURL = "https://view.awsapps.com/start" - + + // Default region for IDC + defaultIDCRegion = "us-east-1" + // Polling interval pollInterval = 5 * time.Second @@ -83,6 +88,429 @@ type CreateTokenResponse struct { RefreshToken string `json:"refreshToken"` } +// getOIDCEndpoint returns the OIDC endpoint for the given region. +func getOIDCEndpoint(region string) string { + if region == "" { + region = defaultIDCRegion + } + return fmt.Sprintf("https://oidc.%s.amazonaws.com", region) +} + +// promptInput prompts the user for input with an optional default value. +func promptInput(prompt, defaultValue string) string { + reader := bufio.NewReader(os.Stdin) + if defaultValue != "" { + fmt.Printf("%s [%s]: ", prompt, defaultValue) + } else { + fmt.Printf("%s: ", prompt) + } + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + if input == "" { + return defaultValue + } + return input +} + +// promptSelect prompts the user to select from options using arrow keys or number input. +func promptSelect(prompt string, options []string) int { + fmt.Println(prompt) + for i, opt := range options { + fmt.Printf(" %d) %s\n", i+1, opt) + } + fmt.Print("Enter selection (1-", len(options), "): ") + + reader := bufio.NewReader(os.Stdin) + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + // Parse the selection + var selection int + if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { + return 0 // Default to first option + } + return selection - 1 +} + +// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. +func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC. +func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": startURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateTokenWithRegion polls for the access token after user authorization using a specific region. +func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, fmt.Errorf("authorization_pending") + } + if errResp.Error == "slow_down" { + return nil, fmt.Errorf("slow_down") + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region. +func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + + // Set headers matching kiro2api's IDC token refresh + // These headers are required for successful IDC token refresh + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) + req.Header.Set("Connection", "keep-alive") + req.Header.Set("x-amz-user-agent", "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE") + req.Header.Set("Accept", "*/*") + req.Header.Set("Accept-Language", "*") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("User-Agent", "node") + req.Header.Set("Accept-Encoding", "br, gzip, deflate") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + StartURL: startURL, + Region: region, + }, nil +} + +// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC). +func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Identity Center) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Register client with the specified region + fmt.Println("\nRegistering client...") + regResp, err := c.RegisterClientWithRegion(ctx, region) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 2: Start device authorization with IDC start URL + fmt.Println("Starting device authorization...") + authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region) + if err != nil { + return nil, fmt.Errorf("failed to start device auth: %w", err) + } + + // Step 3: Show user the verification URL + fmt.Printf("\n") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf(" Confirm the following code in the browser:\n") + fmt.Printf(" Code: %s\n", authResp.UserCode) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete) + + // Set incognito mode based on config + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Open browser + if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" Please open the URL manually in your browser.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + // Step 4: Poll for token + fmt.Println("Waiting for authorization...") + + interval := pollInterval + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + browser.CloseBrowser() + return nil, ctx.Err() + case <-time.After(interval): + tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "authorization_pending") { + fmt.Print(".") + continue + } + if strings.Contains(errStr, "slow_down") { + interval += 5 * time.Second + continue + } + browser.CloseBrowser() + return nil, fmt.Errorf("token creation failed: %w", err) + } + + fmt.Println("\n\n✓ Authorization successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Step 5: Get profile ARN from CodeWhisperer API + fmt.Println("Fetching profile information...") + profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + + // Fetch user email + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + StartURL: startURL, + Region: region, + }, nil + } + } + + // Close browser on timeout + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") +} + +// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login. +func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Prompt for login method + options := []string{ + "Use with Builder ID (personal AWS account)", + "Use with IDC Account (organization SSO)", + } + selection := promptSelect("\n? Select login method:", options) + + if selection == 0 { + // Builder ID flow - use existing implementation + return c.LoginWithBuilderID(ctx) + } + + // IDC flow - prompt for start URL and region + fmt.Println() + startURL := promptInput("? Enter Start URL", "") + if startURL == "" { + return nil, fmt.Errorf("start URL is required for IDC login") + } + + region := promptInput("? Enter Region", defaultIDCRegion) + + return c.LoginWithIDC(ctx, startURL, region) +} + // RegisterClient registers a new OIDC client with AWS. func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { payload := map[string]interface{}{ diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 1da7f25ba..70f23dfb9 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -43,10 +43,15 @@ const ( // Event Stream error type constants ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed - // kiroUserAgent matches amq2api format for User-Agent header + // kiroUserAgent matches amq2api format for User-Agent header (Amazon Q CLI style) kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" - // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api + // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api (Amazon Q CLI style) kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI" + + // Kiro IDE style headers (from kiro2api - for IDC auth) + kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1" + kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1" + kiroIDEAgentModeSpec = "spec" ) // Real-time usage estimation configuration @@ -101,11 +106,24 @@ var kiroEndpointConfigs = []kiroEndpointConfig{ // getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order. // Supports reordering based on "preferred_endpoint" in auth metadata/attributes. +// For IDC auth method, automatically uses CodeWhisperer endpoint with CLI origin. func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { if auth == nil { return kiroEndpointConfigs } + // For IDC auth, use CodeWhisperer endpoint with AI_EDITOR origin (same as Social auth) + // Based on kiro2api analysis: IDC tokens work with CodeWhisperer endpoint using Bearer auth + // The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC) + // NOT in how API calls are made - both Social and IDC use the same endpoint/origin + if auth.Metadata != nil { + authMethod, _ := auth.Metadata["auth_method"].(string) + if authMethod == "idc" { + log.Debugf("kiro: IDC auth, using CodeWhisperer endpoint") + return kiroEndpointConfigs + } + } + // Check for preference var preference string if auth.Metadata != nil { @@ -160,6 +178,79 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { type KiroExecutor struct { cfg *config.Config refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions + + // cognitoCredsCache caches Cognito credentials per auth ID for IDC authentication + // Key: auth.ID, Value: *kiroauth.CognitoCredentials + cognitoCredsCache sync.Map +} + +// getCachedCognitoCredentials retrieves cached Cognito credentials if they are still valid. +func (e *KiroExecutor) getCachedCognitoCredentials(authID string) *kiroauth.CognitoCredentials { + if cached, ok := e.cognitoCredsCache.Load(authID); ok { + creds := cached.(*kiroauth.CognitoCredentials) + if !creds.IsExpired() { + return creds + } + // Credentials expired, remove from cache + e.cognitoCredsCache.Delete(authID) + } + return nil +} + +// cacheCognitoCredentials stores Cognito credentials in the cache. +func (e *KiroExecutor) cacheCognitoCredentials(authID string, creds *kiroauth.CognitoCredentials) { + e.cognitoCredsCache.Store(authID, creds) +} + +// getOrExchangeCognitoCredentials retrieves cached Cognito credentials or exchanges the SSO token for new ones. +func (e *KiroExecutor) getOrExchangeCognitoCredentials(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (*kiroauth.CognitoCredentials, error) { + if auth == nil { + return nil, fmt.Errorf("auth is nil") + } + + // Check cache first + if creds := e.getCachedCognitoCredentials(auth.ID); creds != nil { + log.Debugf("kiro: using cached Cognito credentials for auth %s (expires: %s)", auth.ID, creds.Expiration.Format(time.RFC3339)) + return creds, nil + } + + // Get region from auth metadata + region := "us-east-1" + if auth.Metadata != nil { + if r, ok := auth.Metadata["region"].(string); ok && r != "" { + region = r + } + } + + log.Infof("kiro: exchanging SSO token for Cognito credentials (region: %s)", region) + + // Exchange SSO token for Cognito credentials + cognitoClient := kiroauth.NewCognitoIdentityClient(e.cfg) + creds, err := cognitoClient.ExchangeSSOTokenForCredentials(ctx, accessToken, region) + if err != nil { + return nil, fmt.Errorf("failed to exchange SSO token for Cognito credentials: %w", err) + } + + // Cache the credentials + e.cacheCognitoCredentials(auth.ID, creds) + log.Infof("kiro: Cognito credentials obtained and cached (expires: %s)", creds.Expiration.Format(time.RFC3339)) + + return creds, nil +} + +// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. +func isIDCAuth(auth *cliproxyauth.Auth) bool { + if auth == nil || auth.Metadata == nil { + return false + } + authMethod, _ := auth.Metadata["auth_method"].(string) + return authMethod == "idc" +} + +// signRequestWithSigV4 signs an HTTP request with AWS SigV4 using Cognito credentials. +func signRequestWithSigV4(req *http.Request, payload []byte, creds *kiroauth.CognitoCredentials, region, service string) error { + signer := kiroauth.NewSigV4Signer(creds, region, service) + return signer.SignRequest(req, payload) } // buildKiroPayloadForFormat builds the Kiro API payload based on the source format. @@ -262,15 +353,60 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + + // Use different headers based on auth type + // IDC auth uses Kiro IDE style headers (from kiro2api) + // Other auth types use Amazon Q CLI style headers + if isIDCAuth(auth) { + httpReq.Header.Set("User-Agent", kiroIDEUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) + httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) + log.Debugf("kiro: using Kiro IDE headers for IDC auth") + } else { + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + } httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + // Choose auth method: SigV4 for IDC, Bearer token for others + // NOTE: Cognito credential exchange disabled for now - testing Bearer token first + if false && isIDCAuth(auth) { + // IDC auth requires SigV4 signing with Cognito-exchanged credentials + cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) + if err != nil { + log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) + return resp, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) + } + + // Get region from auth metadata + region := "us-east-1" + if auth.Metadata != nil { + if r, ok := auth.Metadata["region"].(string); ok && r != "" { + region = r + } + } + + // Determine service from URL + service := "codewhisperer" + if strings.Contains(url, "q.us-east-1.amazonaws.com") { + service = "qdeveloper" + } + + // Sign the request with SigV4 + if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { + log.Warnf("kiro: failed to sign request with SigV4: %v", err) + return resp, fmt.Errorf("SigV4 signing failed: %w", err) + } + log.Debugf("kiro: request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) + } else { + // Standard Bearer token authentication for Builder ID, social auth, etc. + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + } + var attrs map[string]string if auth != nil { attrs = auth.Attributes @@ -568,15 +704,60 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } httpReq.Header.Set("Content-Type", kiroContentType) - httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Accept", kiroAcceptStream) // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + + // Use different headers based on auth type + // IDC auth uses Kiro IDE style headers (from kiro2api) + // Other auth types use Amazon Q CLI style headers + if isIDCAuth(auth) { + httpReq.Header.Set("User-Agent", kiroIDEUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) + httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) + log.Debugf("kiro: using Kiro IDE headers for IDC auth") + } else { + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + } httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + // Choose auth method: SigV4 for IDC, Bearer token for others + // NOTE: Cognito credential exchange disabled for now - testing Bearer token first + if false && isIDCAuth(auth) { + // IDC auth requires SigV4 signing with Cognito-exchanged credentials + cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) + if err != nil { + log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) + return nil, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) + } + + // Get region from auth metadata + region := "us-east-1" + if auth.Metadata != nil { + if r, ok := auth.Metadata["region"].(string); ok && r != "" { + region = r + } + } + + // Determine service from URL + service := "codewhisperer" + if strings.Contains(url, "q.us-east-1.amazonaws.com") { + service = "qdeveloper" + } + + // Sign the request with SigV4 + if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { + log.Warnf("kiro: failed to sign request with SigV4: %v", err) + return nil, fmt.Errorf("SigV4 signing failed: %w", err) + } + log.Debugf("kiro: stream request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) + } else { + // Standard Bearer token authentication for Builder ID, social auth, etc. + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + } + var attrs map[string]string if auth != nil { attrs = auth.Attributes @@ -1001,12 +1182,12 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { // This consolidates the auth_method check that was previously done separately. func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { if auth != nil && auth.Metadata != nil { - if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { - // builder-id auth doesn't need profileArn + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") { + // builder-id and idc auth don't need profileArn return "" } } - // For non-builder-id auth (social auth), profileArn is required + // For non-builder-id/idc auth (social auth), profileArn is required if profileArn == "" { log.Warnf("kiro: profile ARN not found in auth, API calls may fail") } diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go index b937152d8..b75cd28ef 100644 --- a/sdk/auth/kiro.go +++ b/sdk/auth/kiro.go @@ -53,20 +53,8 @@ func (a *KiroAuthenticator) RefreshLead() *time.Duration { return &d } -// Login performs OAuth login for Kiro with AWS Builder ID. -func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - oauth := kiroauth.NewKiroOAuth(cfg) - - // Use AWS Builder ID device code flow - tokenData, err := oauth.LoginWithBuilderID(ctx) - if err != nil { - return nil, fmt.Errorf("login failed: %w", err) - } - +// createAuthRecord creates an auth record from token data. +func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) { // Parse expires_at expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) if err != nil { @@ -76,34 +64,63 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts // Extract identifier for file naming idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + // Determine label based on auth method + label := fmt.Sprintf("kiro-%s", source) + if tokenData.AuthMethod == "idc" { + label = "kiro-idc" + } + now := time.Now() - fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + fileName := fmt.Sprintf("%s-%s.json", label, idPart) + + metadata := map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "email": tokenData.Email, + } + + // Add IDC-specific fields if present + if tokenData.StartURL != "" { + metadata["start_url"] = tokenData.StartURL + } + if tokenData.Region != "" { + metadata["region"] = tokenData.Region + } + + attributes := map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": source, + "email": tokenData.Email, + } + + // Add IDC-specific attributes if present + if tokenData.AuthMethod == "idc" { + attributes["source"] = "aws-idc" + if tokenData.StartURL != "" { + attributes["start_url"] = tokenData.StartURL + } + if tokenData.Region != "" { + attributes["region"] = tokenData.Region + } + } record := &coreauth.Auth{ ID: fileName, Provider: "kiro", FileName: fileName, - Label: "kiro-aws", + Label: label, Status: coreauth.StatusActive, CreatedAt: now, UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "client_id": tokenData.ClientID, - "client_secret": tokenData.ClientSecret, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "aws-builder-id", - "email": tokenData.Email, - }, + Metadata: metadata, + Attributes: attributes, // NextRefreshAfter is aligned with RefreshLead (5min) NextRefreshAfter: expiresAt.Add(-5 * time.Minute), } @@ -117,6 +134,23 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts return record, nil } +// Login performs OAuth login for Kiro with AWS (Builder ID or IDC). +// This shows a method selection prompt and handles both flows. +func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + // Use the unified method selection flow (Builder ID or IDC) + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + tokenData, err := ssoClient.LoginWithMethodSelection(ctx) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + + return a.createAuthRecord(tokenData, "aws") +} + // LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow. // This provides a better UX than device code flow as it uses automatic browser callback. func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { @@ -388,15 +422,23 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut clientID, _ := auth.Metadata["client_id"].(string) clientSecret, _ := auth.Metadata["client_secret"].(string) authMethod, _ := auth.Metadata["auth_method"].(string) + startURL, _ := auth.Metadata["start_url"].(string) + region, _ := auth.Metadata["region"].(string) var tokenData *kiroauth.KiroTokenData var err error - // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint - if clientID != "" && clientSecret != "" && authMethod == "builder-id" { - ssoClient := kiroauth.NewSSOOIDCClient(cfg) + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + + // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + // IDC refresh with region-specific endpoint + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + // Builder ID refresh with default endpoint tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) - } else { + default: // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) oauth := kiroauth.NewKiroOAuth(cfg) tokenData, err = oauth.RefreshToken(ctx, refreshToken) From 98db5aabd0591b19b444db0f009816d86c76a932 Mon Sep 17 00:00:00 2001 From: Joao Date: Mon, 22 Dec 2025 12:23:10 +0000 Subject: [PATCH 2/3] feat: persist refreshed IDC tokens to auth file Add persistRefreshedAuth function to write refreshed tokens back to the auth file after inline token refresh. This prevents repeated token refreshes on every request when the token expires. Changes: - Add persistRefreshedAuth() to kiro_executor.go - Call persist after all token refresh paths (401, 403, pre-request) - Remove unused log import from sdk/auth/kiro.go --- internal/auth/kiro/cognito.go | 408 --------------------- internal/auth/kiro/sso_oidc.go | 2 +- internal/runtime/executor/kiro_executor.go | 236 +++++------- 3 files changed, 101 insertions(+), 545 deletions(-) delete mode 100644 internal/auth/kiro/cognito.go diff --git a/internal/auth/kiro/cognito.go b/internal/auth/kiro/cognito.go deleted file mode 100644 index 7cf328186..000000000 --- a/internal/auth/kiro/cognito.go +++ /dev/null @@ -1,408 +0,0 @@ -// Package kiro provides Cognito Identity credential exchange for IDC authentication. -// AWS Identity Center (IDC) requires SigV4 signing with Cognito-exchanged credentials -// instead of Bearer token authentication. -package kiro - -import ( - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "net/http" - "sort" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - log "github.com/sirupsen/logrus" -) - -const ( - // Cognito Identity endpoints - cognitoIdentityEndpoint = "https://cognito-identity.us-east-1.amazonaws.com" - - // Identity Pool ID for Q Developer / CodeWhisperer - // This is the identity pool used by kiro-cli and Amazon Q CLI - cognitoIdentityPoolID = "us-east-1:70717e99-906f-485d-8d89-c89a0b5d49c5" - - // Cognito provider name for SSO OIDC - cognitoProviderName = "cognito-identity.amazonaws.com" -) - -// CognitoCredentials holds temporary AWS credentials from Cognito Identity. -type CognitoCredentials struct { - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - SessionToken string `json:"session_token"` - Expiration time.Time `json:"expiration"` -} - -// CognitoIdentityClient handles Cognito Identity credential exchange. -type CognitoIdentityClient struct { - httpClient *http.Client - cfg *config.Config -} - -// NewCognitoIdentityClient creates a new Cognito Identity client. -func NewCognitoIdentityClient(cfg *config.Config) *CognitoIdentityClient { - client := &http.Client{Timeout: 30 * time.Second} - if cfg != nil { - client = util.SetProxy(&cfg.SDKConfig, client) - } - return &CognitoIdentityClient{ - httpClient: client, - cfg: cfg, - } -} - -// GetIdentityID retrieves a Cognito Identity ID using the SSO access token. -func (c *CognitoIdentityClient) GetIdentityID(ctx context.Context, accessToken, region string) (string, error) { - if region == "" { - region = "us-east-1" - } - - endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) - - // Build the GetId request - // The SSO token is passed as a login token for the identity pool - payload := map[string]interface{}{ - "IdentityPoolId": cognitoIdentityPoolID, - "Logins": map[string]string{ - // Use the OIDC provider URL as the key - fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, - }, - } - - body, err := json.Marshal(payload) - if err != nil { - return "", fmt.Errorf("failed to marshal GetId request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) - if err != nil { - return "", fmt.Errorf("failed to create GetId request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.1") - req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetId") - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return "", fmt.Errorf("GetId request failed: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read GetId response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("Cognito GetId failed (status %d): %s", resp.StatusCode, string(respBody)) - return "", fmt.Errorf("GetId failed (status %d): %s", resp.StatusCode, string(respBody)) - } - - var result struct { - IdentityID string `json:"IdentityId"` - } - if err := json.Unmarshal(respBody, &result); err != nil { - return "", fmt.Errorf("failed to parse GetId response: %w", err) - } - - if result.IdentityID == "" { - return "", fmt.Errorf("empty IdentityId in GetId response") - } - - log.Debugf("Cognito Identity ID: %s", result.IdentityID) - return result.IdentityID, nil -} - -// GetCredentialsForIdentity exchanges an identity ID and login token for temporary AWS credentials. -func (c *CognitoIdentityClient) GetCredentialsForIdentity(ctx context.Context, identityID, accessToken, region string) (*CognitoCredentials, error) { - if region == "" { - region = "us-east-1" - } - - endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "IdentityId": identityID, - "Logins": map[string]string{ - fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, - }, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal GetCredentialsForIdentity request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) - if err != nil { - return nil, fmt.Errorf("failed to create GetCredentialsForIdentity request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.1") - req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetCredentialsForIdentity") - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("GetCredentialsForIdentity request failed: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read GetCredentialsForIdentity response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("Cognito GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) - } - - var result struct { - Credentials struct { - AccessKeyID string `json:"AccessKeyId"` - SecretKey string `json:"SecretKey"` - SessionToken string `json:"SessionToken"` - Expiration int64 `json:"Expiration"` - } `json:"Credentials"` - IdentityID string `json:"IdentityId"` - } - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to parse GetCredentialsForIdentity response: %w", err) - } - - if result.Credentials.AccessKeyID == "" { - return nil, fmt.Errorf("empty AccessKeyId in GetCredentialsForIdentity response") - } - - // Expiration is in seconds since epoch - expiration := time.Unix(result.Credentials.Expiration, 0) - - log.Debugf("Cognito credentials obtained, expires: %s", expiration.Format(time.RFC3339)) - - return &CognitoCredentials{ - AccessKeyID: result.Credentials.AccessKeyID, - SecretAccessKey: result.Credentials.SecretKey, - SessionToken: result.Credentials.SessionToken, - Expiration: expiration, - }, nil -} - -// ExchangeSSOTokenForCredentials is a convenience method that performs the full -// Cognito Identity credential exchange flow: GetId -> GetCredentialsForIdentity -func (c *CognitoIdentityClient) ExchangeSSOTokenForCredentials(ctx context.Context, accessToken, region string) (*CognitoCredentials, error) { - log.Debugf("Exchanging SSO token for Cognito credentials (region: %s)", region) - - // Step 1: Get Identity ID - identityID, err := c.GetIdentityID(ctx, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to get identity ID: %w", err) - } - - // Step 2: Get credentials for the identity - creds, err := c.GetCredentialsForIdentity(ctx, identityID, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to get credentials for identity: %w", err) - } - - return creds, nil -} - -// SigV4Signer provides AWS Signature Version 4 signing for HTTP requests. -type SigV4Signer struct { - credentials *CognitoCredentials - region string - service string -} - -// NewSigV4Signer creates a new SigV4 signer with the given credentials. -func NewSigV4Signer(creds *CognitoCredentials, region, service string) *SigV4Signer { - return &SigV4Signer{ - credentials: creds, - region: region, - service: service, - } -} - -// SignRequest signs an HTTP request using AWS Signature Version 4. -// The request body must be provided separately since it may have been read already. -func (s *SigV4Signer) SignRequest(req *http.Request, body []byte) error { - now := time.Now().UTC() - amzDate := now.Format("20060102T150405Z") - dateStamp := now.Format("20060102") - - // Ensure required headers are set - if req.Header.Get("Host") == "" { - req.Header.Set("Host", req.URL.Host) - } - req.Header.Set("X-Amz-Date", amzDate) - if s.credentials.SessionToken != "" { - req.Header.Set("X-Amz-Security-Token", s.credentials.SessionToken) - } - - // Create canonical request - canonicalRequest, signedHeaders := s.createCanonicalRequest(req, body) - - // Create string to sign - algorithm := "AWS4-HMAC-SHA256" - credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, s.region, s.service) - stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", - algorithm, - amzDate, - credentialScope, - hashSHA256([]byte(canonicalRequest)), - ) - - // Calculate signature - signingKey := s.getSignatureKey(dateStamp) - signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) - - // Build Authorization header - authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", - algorithm, - s.credentials.AccessKeyID, - credentialScope, - signedHeaders, - signature, - ) - - req.Header.Set("Authorization", authHeader) - - return nil -} - -// createCanonicalRequest builds the canonical request string for SigV4. -func (s *SigV4Signer) createCanonicalRequest(req *http.Request, body []byte) (string, string) { - // HTTP method - method := req.Method - - // Canonical URI - uri := req.URL.Path - if uri == "" { - uri = "/" - } - - // Canonical query string (sorted) - queryString := s.buildCanonicalQueryString(req) - - // Canonical headers (sorted, lowercase) - canonicalHeaders, signedHeaders := s.buildCanonicalHeaders(req) - - // Hashed payload - payloadHash := hashSHA256(body) - - canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", - method, - uri, - queryString, - canonicalHeaders, - signedHeaders, - payloadHash, - ) - - return canonicalRequest, signedHeaders -} - -// buildCanonicalQueryString builds a sorted, URI-encoded query string. -func (s *SigV4Signer) buildCanonicalQueryString(req *http.Request) string { - if req.URL.RawQuery == "" { - return "" - } - - // Parse and sort query parameters - params := make([]string, 0) - for key, values := range req.URL.Query() { - for _, value := range values { - params = append(params, fmt.Sprintf("%s=%s", uriEncode(key), uriEncode(value))) - } - } - sort.Strings(params) - return strings.Join(params, "&") -} - -// buildCanonicalHeaders builds sorted, lowercase canonical headers. -func (s *SigV4Signer) buildCanonicalHeaders(req *http.Request) (string, string) { - // Headers to sign (must include host and x-amz-*) - headerMap := make(map[string]string) - headerMap["host"] = req.URL.Host - - for key, values := range req.Header { - lowKey := strings.ToLower(key) - // Include x-amz-* headers and content-type - if strings.HasPrefix(lowKey, "x-amz-") || lowKey == "content-type" { - headerMap[lowKey] = strings.TrimSpace(values[0]) - } - } - - // Sort header names - headerNames := make([]string, 0, len(headerMap)) - for name := range headerMap { - headerNames = append(headerNames, name) - } - sort.Strings(headerNames) - - // Build canonical headers and signed headers - var canonicalHeaders strings.Builder - for _, name := range headerNames { - canonicalHeaders.WriteString(name) - canonicalHeaders.WriteString(":") - canonicalHeaders.WriteString(headerMap[name]) - canonicalHeaders.WriteString("\n") - } - - signedHeaders := strings.Join(headerNames, ";") - - return canonicalHeaders.String(), signedHeaders -} - -// getSignatureKey derives the signing key for SigV4. -func (s *SigV4Signer) getSignatureKey(dateStamp string) []byte { - kDate := hmacSHA256([]byte("AWS4"+s.credentials.SecretAccessKey), []byte(dateStamp)) - kRegion := hmacSHA256(kDate, []byte(s.region)) - kService := hmacSHA256(kRegion, []byte(s.service)) - kSigning := hmacSHA256(kService, []byte("aws4_request")) - return kSigning -} - -// hmacSHA256 computes HMAC-SHA256. -func hmacSHA256(key, data []byte) []byte { - h := hmac.New(sha256.New, key) - h.Write(data) - return h.Sum(nil) -} - -// hashSHA256 computes SHA256 hash and returns hex string. -func hashSHA256(data []byte) string { - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]) -} - -// uriEncode performs URI encoding for SigV4. -func uriEncode(s string) string { - var result strings.Builder - for i := 0; i < len(s); i++ { - c := s[i] - if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || - (c >= '0' && c <= '9') || c == '-' || c == '.' || c == '_' || c == '~' { - result.WriteByte(c) - } else { - result.WriteString(fmt.Sprintf("%%%02X", c)) - } - } - return result.String() -} - -// IsExpired checks if the credentials are expired or about to expire. -func (c *CognitoCredentials) IsExpired() bool { - // Consider expired if within 5 minutes of expiration - return time.Now().Add(5 * time.Minute).After(c.Expiration) -} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 6ef2e9605..292f5bcff 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -334,7 +334,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl } if resp.StatusCode != http.StatusOK { - log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) } diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 70f23dfb9..1e882888d 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -10,6 +10,8 @@ import ( "fmt" "io" "net/http" + "os" + "path/filepath" "strings" "sync" "time" @@ -178,64 +180,6 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { type KiroExecutor struct { cfg *config.Config refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions - - // cognitoCredsCache caches Cognito credentials per auth ID for IDC authentication - // Key: auth.ID, Value: *kiroauth.CognitoCredentials - cognitoCredsCache sync.Map -} - -// getCachedCognitoCredentials retrieves cached Cognito credentials if they are still valid. -func (e *KiroExecutor) getCachedCognitoCredentials(authID string) *kiroauth.CognitoCredentials { - if cached, ok := e.cognitoCredsCache.Load(authID); ok { - creds := cached.(*kiroauth.CognitoCredentials) - if !creds.IsExpired() { - return creds - } - // Credentials expired, remove from cache - e.cognitoCredsCache.Delete(authID) - } - return nil -} - -// cacheCognitoCredentials stores Cognito credentials in the cache. -func (e *KiroExecutor) cacheCognitoCredentials(authID string, creds *kiroauth.CognitoCredentials) { - e.cognitoCredsCache.Store(authID, creds) -} - -// getOrExchangeCognitoCredentials retrieves cached Cognito credentials or exchanges the SSO token for new ones. -func (e *KiroExecutor) getOrExchangeCognitoCredentials(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (*kiroauth.CognitoCredentials, error) { - if auth == nil { - return nil, fmt.Errorf("auth is nil") - } - - // Check cache first - if creds := e.getCachedCognitoCredentials(auth.ID); creds != nil { - log.Debugf("kiro: using cached Cognito credentials for auth %s (expires: %s)", auth.ID, creds.Expiration.Format(time.RFC3339)) - return creds, nil - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - log.Infof("kiro: exchanging SSO token for Cognito credentials (region: %s)", region) - - // Exchange SSO token for Cognito credentials - cognitoClient := kiroauth.NewCognitoIdentityClient(e.cfg) - creds, err := cognitoClient.ExchangeSSOTokenForCredentials(ctx, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to exchange SSO token for Cognito credentials: %w", err) - } - - // Cache the credentials - e.cacheCognitoCredentials(auth.ID, creds) - log.Infof("kiro: Cognito credentials obtained and cached (expires: %s)", creds.Expiration.Format(time.RFC3339)) - - return creds, nil } // isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. @@ -247,12 +191,6 @@ func isIDCAuth(auth *cliproxyauth.Auth) bool { return authMethod == "idc" } -// signRequestWithSigV4 signs an HTTP request with AWS SigV4 using Cognito credentials. -func signRequestWithSigV4(req *http.Request, payload []byte, creds *kiroauth.CognitoCredentials, region, service string) error { - signer := kiroauth.NewSigV4Signer(creds, region, service) - return signer.SignRequest(req, payload) -} - // buildKiroPayloadForFormat builds the Kiro API payload based on the source format. // This is critical because OpenAI and Claude formats have different tool structures: // - OpenAI: tools[].function.name, tools[].function.description @@ -301,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before request") } @@ -372,40 +314,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - // Choose auth method: SigV4 for IDC, Bearer token for others - // NOTE: Cognito credential exchange disabled for now - testing Bearer token first - if false && isIDCAuth(auth) { - // IDC auth requires SigV4 signing with Cognito-exchanged credentials - cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) - if err != nil { - log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) - return resp, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - // Determine service from URL - service := "codewhisperer" - if strings.Contains(url, "q.us-east-1.amazonaws.com") { - service = "qdeveloper" - } - - // Sign the request with SigV4 - if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { - log.Warnf("kiro: failed to sign request with SigV4: %v", err) - return resp, fmt.Errorf("SigV4 signing failed: %w", err) - } - log.Debugf("kiro: request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) - } else { - // Standard Bearer token authentication for Builder ID, social auth, etc. - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - } + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) var attrs map[string]string if auth != nil { @@ -494,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -552,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying request") @@ -654,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before stream request") } @@ -723,40 +647,8 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - // Choose auth method: SigV4 for IDC, Bearer token for others - // NOTE: Cognito credential exchange disabled for now - testing Bearer token first - if false && isIDCAuth(auth) { - // IDC auth requires SigV4 signing with Cognito-exchanged credentials - cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) - if err != nil { - log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) - return nil, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - // Determine service from URL - service := "codewhisperer" - if strings.Contains(url, "q.us-east-1.amazonaws.com") { - service = "qdeveloper" - } - - // Sign the request with SigV4 - if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { - log.Warnf("kiro: failed to sign request with SigV4: %v", err) - return nil, fmt.Errorf("SigV4 signing failed: %w", err) - } - log.Debugf("kiro: stream request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) - } else { - // Standard Bearer token authentication for Builder ID, social auth, etc. - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - } + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) var attrs map[string]string if auth != nil { @@ -858,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -916,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying stream request") @@ -3191,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var refreshToken string var clientID, clientSecret string var authMethod string + var region, startURL string if auth.Metadata != nil { if rt, ok := auth.Metadata["refresh_token"].(string); ok { @@ -3205,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c if am, ok := auth.Metadata["auth_method"].(string); ok { authMethod = am } + if r, ok := auth.Metadata["region"].(string); ok { + region = r + } + if su, ok := auth.Metadata["start_url"].(string); ok { + startURL = su + } } if refreshToken == "" { @@ -3214,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var tokenData *kiroauth.KiroTokenData var err error - // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint - if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + + // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + // IDC refresh with region-specific endpoint + log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region) + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + // Builder ID refresh with default endpoint log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") - ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) - } else { + default: + // Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub) log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") oauth := kiroauth.NewKiroOAuth(e.cfg) tokenData, err = oauth.RefreshToken(ctx, refreshToken) @@ -3275,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c return updated, nil } +// persistRefreshedAuth persists a refreshed auth record to disk. +// This ensures token refreshes from inline retry are saved to the auth file. +func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { + if auth == nil || auth.Metadata == nil { + return fmt.Errorf("kiro executor: cannot persist nil auth or metadata") + } + + // Determine the file path from auth attributes or filename + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return fmt.Errorf("kiro executor: auth has no file path or filename") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return fmt.Errorf("kiro executor: cannot determine auth file path") + } + } + + // Marshal metadata to JSON + raw, err := json.Marshal(auth.Metadata) + if err != nil { + return fmt.Errorf("kiro executor: marshal metadata failed: %w", err) + } + + // Write to temp file first, then rename (atomic write) + tmp := authPath + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return fmt.Errorf("kiro executor: write temp auth file failed: %w", err) + } + if err := os.Rename(tmp, authPath); err != nil { + return fmt.Errorf("kiro executor: rename auth file failed: %w", err) + } + + log.Debugf("kiro executor: persisted refreshed auth to %s", authPath) + return nil +} + // isTokenExpired checks if a JWT access token has expired. // Returns true if the token is expired or cannot be parsed. func (e *KiroExecutor) isTokenExpired(accessToken string) bool { From 349b2ba3afa42ad7f5374fe72665e34f01e7206e Mon Sep 17 00:00:00 2001 From: Joao Date: Tue, 23 Dec 2025 10:20:14 +0000 Subject: [PATCH 3/3] refactor: improve error handling and code quality - Handle errors in promptInput instead of ignoring them - Improve promptSelect to provide feedback on invalid input and re-prompt - Use sentinel errors (ErrAuthorizationPending, ErrSlowDown) instead of string-based error checking with strings.Contains - Move hardcoded x-amz-user-agent header to idcAmzUserAgent constant Addresses code review feedback from Gemini Code Assist. --- internal/auth/kiro/sso_oidc.go | 76 +++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 292f5bcff..ab44e55f6 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "html" "io" @@ -35,13 +36,22 @@ const ( // Polling interval pollInterval = 5 * time.Second - + // Authorization code flow callback authCodeCallbackPath = "/oauth/callback" authCodeCallbackPort = 19877 - + // User-Agent to match official Kiro IDE kiroUserAgent = "KiroIDE" + + // IDC token refresh headers (matching Kiro IDE behavior) + idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE" +) + +// Sentinel errors for OIDC token polling +var ( + ErrAuthorizationPending = errors.New("authorization_pending") + ErrSlowDown = errors.New("slow_down") ) // SSOOIDCClient handles AWS SSO OIDC authentication. @@ -104,7 +114,11 @@ func promptInput(prompt, defaultValue string) string { } else { fmt.Printf("%s: ", prompt) } - input, _ := reader.ReadString('\n') + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return defaultValue + } input = strings.TrimSpace(input) if input == "" { return defaultValue @@ -112,24 +126,32 @@ func promptInput(prompt, defaultValue string) string { return input } -// promptSelect prompts the user to select from options using arrow keys or number input. +// promptSelect prompts the user to select from options using number input. func promptSelect(prompt string, options []string) int { - fmt.Println(prompt) - for i, opt := range options { - fmt.Printf(" %d) %s\n", i+1, opt) - } - fmt.Print("Enter selection (1-", len(options), "): ") - reader := bufio.NewReader(os.Stdin) - input, _ := reader.ReadString('\n') - input = strings.TrimSpace(input) - // Parse the selection - var selection int - if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { - return 0 // Default to first option + for { + fmt.Println(prompt) + for i, opt := range options { + fmt.Printf(" %d) %s\n", i+1, opt) + } + fmt.Printf("Enter selection (1-%d): ", len(options)) + + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return 0 // Default to first option on error + } + input = strings.TrimSpace(input) + + // Parse the selection + var selection int + if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { + fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options)) + continue + } + return selection - 1 } - return selection - 1 } // RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. @@ -266,10 +288,10 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli } if json.Unmarshal(respBody, &errResp) == nil { if errResp.Error == "authorization_pending" { - return nil, fmt.Errorf("authorization_pending") + return nil, ErrAuthorizationPending } if errResp.Error == "slow_down" { - return nil, fmt.Errorf("slow_down") + return nil, ErrSlowDown } } log.Debugf("create token failed: %s", string(respBody)) @@ -315,7 +337,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl req.Header.Set("Content-Type", "application/json") req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) req.Header.Set("Connection", "keep-alive") - req.Header.Set("x-amz-user-agent", "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE") + req.Header.Set("x-amz-user-agent", idcAmzUserAgent) req.Header.Set("Accept", "*/*") req.Header.Set("Accept-Language", "*") req.Header.Set("sec-fetch-mode", "cors") @@ -426,12 +448,11 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin case <-time.After(interval): tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region) if err != nil { - errStr := err.Error() - if strings.Contains(errStr, "authorization_pending") { + if errors.Is(err, ErrAuthorizationPending) { fmt.Print(".") continue } - if strings.Contains(errStr, "slow_down") { + if errors.Is(err, ErrSlowDown) { interval += 5 * time.Second continue } @@ -639,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, } if json.Unmarshal(respBody, &errResp) == nil { if errResp.Error == "authorization_pending" { - return nil, fmt.Errorf("authorization_pending") + return nil, ErrAuthorizationPending } if errResp.Error == "slow_down" { - return nil, fmt.Errorf("slow_down") + return nil, ErrSlowDown } } log.Debugf("create token failed: %s", string(respBody)) @@ -787,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, case <-time.After(interval): tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) if err != nil { - errStr := err.Error() - if strings.Contains(errStr, "authorization_pending") { + if errors.Is(err, ErrAuthorizationPending) { fmt.Print(".") continue } - if strings.Contains(errStr, "slow_down") { + if errors.Is(err, ErrSlowDown) { interval += 5 * time.Second continue }