Skip to content

Commit 349b2ba

Browse files
committed
refactor: improve error handling and code quality
- Handle errors in promptInput instead of ignoring them - Improve promptSelect to provide feedback on invalid input and re-prompt - Use sentinel errors (ErrAuthorizationPending, ErrSlowDown) instead of string-based error checking with strings.Contains - Move hardcoded x-amz-user-agent header to idcAmzUserAgent constant Addresses code review feedback from Gemini Code Assist.
1 parent 98db5aa commit 349b2ba

File tree

1 file changed

+48
-28
lines changed

1 file changed

+48
-28
lines changed

internal/auth/kiro/sso_oidc.go

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"crypto/sha256"
99
"encoding/base64"
1010
"encoding/json"
11+
"errors"
1112
"fmt"
1213
"html"
1314
"io"
@@ -35,13 +36,22 @@ const (
3536

3637
// Polling interval
3738
pollInterval = 5 * time.Second
38-
39+
3940
// Authorization code flow callback
4041
authCodeCallbackPath = "/oauth/callback"
4142
authCodeCallbackPort = 19877
42-
43+
4344
// User-Agent to match official Kiro IDE
4445
kiroUserAgent = "KiroIDE"
46+
47+
// IDC token refresh headers (matching Kiro IDE behavior)
48+
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
49+
)
50+
51+
// Sentinel errors for OIDC token polling
52+
var (
53+
ErrAuthorizationPending = errors.New("authorization_pending")
54+
ErrSlowDown = errors.New("slow_down")
4555
)
4656

4757
// SSOOIDCClient handles AWS SSO OIDC authentication.
@@ -104,32 +114,44 @@ func promptInput(prompt, defaultValue string) string {
104114
} else {
105115
fmt.Printf("%s: ", prompt)
106116
}
107-
input, _ := reader.ReadString('\n')
117+
input, err := reader.ReadString('\n')
118+
if err != nil {
119+
log.Warnf("Error reading input: %v", err)
120+
return defaultValue
121+
}
108122
input = strings.TrimSpace(input)
109123
if input == "" {
110124
return defaultValue
111125
}
112126
return input
113127
}
114128

115-
// promptSelect prompts the user to select from options using arrow keys or number input.
129+
// promptSelect prompts the user to select from options using number input.
116130
func promptSelect(prompt string, options []string) int {
117-
fmt.Println(prompt)
118-
for i, opt := range options {
119-
fmt.Printf(" %d) %s\n", i+1, opt)
120-
}
121-
fmt.Print("Enter selection (1-", len(options), "): ")
122-
123131
reader := bufio.NewReader(os.Stdin)
124-
input, _ := reader.ReadString('\n')
125-
input = strings.TrimSpace(input)
126132

127-
// Parse the selection
128-
var selection int
129-
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
130-
return 0 // Default to first option
133+
for {
134+
fmt.Println(prompt)
135+
for i, opt := range options {
136+
fmt.Printf(" %d) %s\n", i+1, opt)
137+
}
138+
fmt.Printf("Enter selection (1-%d): ", len(options))
139+
140+
input, err := reader.ReadString('\n')
141+
if err != nil {
142+
log.Warnf("Error reading input: %v", err)
143+
return 0 // Default to first option on error
144+
}
145+
input = strings.TrimSpace(input)
146+
147+
// Parse the selection
148+
var selection int
149+
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
150+
fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
151+
continue
152+
}
153+
return selection - 1
131154
}
132-
return selection - 1
133155
}
134156

135157
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
@@ -266,10 +288,10 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
266288
}
267289
if json.Unmarshal(respBody, &errResp) == nil {
268290
if errResp.Error == "authorization_pending" {
269-
return nil, fmt.Errorf("authorization_pending")
291+
return nil, ErrAuthorizationPending
270292
}
271293
if errResp.Error == "slow_down" {
272-
return nil, fmt.Errorf("slow_down")
294+
return nil, ErrSlowDown
273295
}
274296
}
275297
log.Debugf("create token failed: %s", string(respBody))
@@ -315,7 +337,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl
315337
req.Header.Set("Content-Type", "application/json")
316338
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
317339
req.Header.Set("Connection", "keep-alive")
318-
req.Header.Set("x-amz-user-agent", "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE")
340+
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
319341
req.Header.Set("Accept", "*/*")
320342
req.Header.Set("Accept-Language", "*")
321343
req.Header.Set("sec-fetch-mode", "cors")
@@ -426,12 +448,11 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
426448
case <-time.After(interval):
427449
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
428450
if err != nil {
429-
errStr := err.Error()
430-
if strings.Contains(errStr, "authorization_pending") {
451+
if errors.Is(err, ErrAuthorizationPending) {
431452
fmt.Print(".")
432453
continue
433454
}
434-
if strings.Contains(errStr, "slow_down") {
455+
if errors.Is(err, ErrSlowDown) {
435456
interval += 5 * time.Second
436457
continue
437458
}
@@ -639,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
639660
}
640661
if json.Unmarshal(respBody, &errResp) == nil {
641662
if errResp.Error == "authorization_pending" {
642-
return nil, fmt.Errorf("authorization_pending")
663+
return nil, ErrAuthorizationPending
643664
}
644665
if errResp.Error == "slow_down" {
645-
return nil, fmt.Errorf("slow_down")
666+
return nil, ErrSlowDown
646667
}
647668
}
648669
log.Debugf("create token failed: %s", string(respBody))
@@ -787,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
787808
case <-time.After(interval):
788809
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
789810
if err != nil {
790-
errStr := err.Error()
791-
if strings.Contains(errStr, "authorization_pending") {
811+
if errors.Is(err, ErrAuthorizationPending) {
792812
fmt.Print(".")
793813
continue
794814
}
795-
if strings.Contains(errStr, "slow_down") {
815+
if errors.Is(err, ErrSlowDown) {
796816
interval += 5 * time.Second
797817
continue
798818
}

0 commit comments

Comments
 (0)