|
8 | 8 | "crypto/sha256" |
9 | 9 | "encoding/base64" |
10 | 10 | "encoding/json" |
| 11 | + "errors" |
11 | 12 | "fmt" |
12 | 13 | "html" |
13 | 14 | "io" |
@@ -35,13 +36,22 @@ const ( |
35 | 36 |
|
36 | 37 | // Polling interval |
37 | 38 | pollInterval = 5 * time.Second |
38 | | - |
| 39 | + |
39 | 40 | // Authorization code flow callback |
40 | 41 | authCodeCallbackPath = "/oauth/callback" |
41 | 42 | authCodeCallbackPort = 19877 |
42 | | - |
| 43 | + |
43 | 44 | // User-Agent to match official Kiro IDE |
44 | 45 | kiroUserAgent = "KiroIDE" |
| 46 | + |
| 47 | + // IDC token refresh headers (matching Kiro IDE behavior) |
| 48 | + idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE" |
| 49 | +) |
| 50 | + |
| 51 | +// Sentinel errors for OIDC token polling |
| 52 | +var ( |
| 53 | + ErrAuthorizationPending = errors.New("authorization_pending") |
| 54 | + ErrSlowDown = errors.New("slow_down") |
45 | 55 | ) |
46 | 56 |
|
47 | 57 | // SSOOIDCClient handles AWS SSO OIDC authentication. |
@@ -104,32 +114,44 @@ func promptInput(prompt, defaultValue string) string { |
104 | 114 | } else { |
105 | 115 | fmt.Printf("%s: ", prompt) |
106 | 116 | } |
107 | | - input, _ := reader.ReadString('\n') |
| 117 | + input, err := reader.ReadString('\n') |
| 118 | + if err != nil { |
| 119 | + log.Warnf("Error reading input: %v", err) |
| 120 | + return defaultValue |
| 121 | + } |
108 | 122 | input = strings.TrimSpace(input) |
109 | 123 | if input == "" { |
110 | 124 | return defaultValue |
111 | 125 | } |
112 | 126 | return input |
113 | 127 | } |
114 | 128 |
|
115 | | -// promptSelect prompts the user to select from options using arrow keys or number input. |
| 129 | +// promptSelect prompts the user to select from options using number input. |
116 | 130 | func promptSelect(prompt string, options []string) int { |
117 | | - fmt.Println(prompt) |
118 | | - for i, opt := range options { |
119 | | - fmt.Printf(" %d) %s\n", i+1, opt) |
120 | | - } |
121 | | - fmt.Print("Enter selection (1-", len(options), "): ") |
122 | | - |
123 | 131 | reader := bufio.NewReader(os.Stdin) |
124 | | - input, _ := reader.ReadString('\n') |
125 | | - input = strings.TrimSpace(input) |
126 | 132 |
|
127 | | - // Parse the selection |
128 | | - var selection int |
129 | | - if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { |
130 | | - return 0 // Default to first option |
| 133 | + for { |
| 134 | + fmt.Println(prompt) |
| 135 | + for i, opt := range options { |
| 136 | + fmt.Printf(" %d) %s\n", i+1, opt) |
| 137 | + } |
| 138 | + fmt.Printf("Enter selection (1-%d): ", len(options)) |
| 139 | + |
| 140 | + input, err := reader.ReadString('\n') |
| 141 | + if err != nil { |
| 142 | + log.Warnf("Error reading input: %v", err) |
| 143 | + return 0 // Default to first option on error |
| 144 | + } |
| 145 | + input = strings.TrimSpace(input) |
| 146 | + |
| 147 | + // Parse the selection |
| 148 | + var selection int |
| 149 | + if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { |
| 150 | + fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options)) |
| 151 | + continue |
| 152 | + } |
| 153 | + return selection - 1 |
131 | 154 | } |
132 | | - return selection - 1 |
133 | 155 | } |
134 | 156 |
|
135 | 157 | // RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. |
@@ -266,10 +288,10 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli |
266 | 288 | } |
267 | 289 | if json.Unmarshal(respBody, &errResp) == nil { |
268 | 290 | if errResp.Error == "authorization_pending" { |
269 | | - return nil, fmt.Errorf("authorization_pending") |
| 291 | + return nil, ErrAuthorizationPending |
270 | 292 | } |
271 | 293 | if errResp.Error == "slow_down" { |
272 | | - return nil, fmt.Errorf("slow_down") |
| 294 | + return nil, ErrSlowDown |
273 | 295 | } |
274 | 296 | } |
275 | 297 | log.Debugf("create token failed: %s", string(respBody)) |
@@ -315,7 +337,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl |
315 | 337 | req.Header.Set("Content-Type", "application/json") |
316 | 338 | req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) |
317 | 339 | req.Header.Set("Connection", "keep-alive") |
318 | | - req.Header.Set("x-amz-user-agent", "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE") |
| 340 | + req.Header.Set("x-amz-user-agent", idcAmzUserAgent) |
319 | 341 | req.Header.Set("Accept", "*/*") |
320 | 342 | req.Header.Set("Accept-Language", "*") |
321 | 343 | req.Header.Set("sec-fetch-mode", "cors") |
@@ -426,12 +448,11 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin |
426 | 448 | case <-time.After(interval): |
427 | 449 | tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region) |
428 | 450 | if err != nil { |
429 | | - errStr := err.Error() |
430 | | - if strings.Contains(errStr, "authorization_pending") { |
| 451 | + if errors.Is(err, ErrAuthorizationPending) { |
431 | 452 | fmt.Print(".") |
432 | 453 | continue |
433 | 454 | } |
434 | | - if strings.Contains(errStr, "slow_down") { |
| 455 | + if errors.Is(err, ErrSlowDown) { |
435 | 456 | interval += 5 * time.Second |
436 | 457 | continue |
437 | 458 | } |
@@ -639,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, |
639 | 660 | } |
640 | 661 | if json.Unmarshal(respBody, &errResp) == nil { |
641 | 662 | if errResp.Error == "authorization_pending" { |
642 | | - return nil, fmt.Errorf("authorization_pending") |
| 663 | + return nil, ErrAuthorizationPending |
643 | 664 | } |
644 | 665 | if errResp.Error == "slow_down" { |
645 | | - return nil, fmt.Errorf("slow_down") |
| 666 | + return nil, ErrSlowDown |
646 | 667 | } |
647 | 668 | } |
648 | 669 | log.Debugf("create token failed: %s", string(respBody)) |
@@ -787,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, |
787 | 808 | case <-time.After(interval): |
788 | 809 | tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) |
789 | 810 | if err != nil { |
790 | | - errStr := err.Error() |
791 | | - if strings.Contains(errStr, "authorization_pending") { |
| 811 | + if errors.Is(err, ErrAuthorizationPending) { |
792 | 812 | fmt.Print(".") |
793 | 813 | continue |
794 | 814 | } |
795 | | - if strings.Contains(errStr, "slow_down") { |
| 815 | + if errors.Is(err, ErrSlowDown) { |
796 | 816 | interval += 5 * time.Second |
797 | 817 | continue |
798 | 818 | } |
|
0 commit comments