diff --git a/internal/commands/root.go b/internal/commands/root.go index 453070073..91f344304 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -114,6 +114,7 @@ func NewAstCLI( // This monitors and traps situations where "extra/garbage" commands // are passed to Cobra. rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + CheckPreferredCredentials(cmd) err := customLogConfiguration(rootCmd) if err != nil { return err @@ -416,3 +417,17 @@ func setLogOutputFromFlag(flag, dirPath string) error { log.SetOutput(multiWriter) return nil } +func CheckPreferredCredentials(cmd *cobra.Command) string { + if cmd.Flags().Changed(params.AccessKeyIDFlag) && + cmd.Flags().Changed(params.AccessKeySecretFlag) { + viper.Set(params.PreferredCredentialTypeKey, "access_key") + return "access_key" + } else if cmd.Flags().Changed(params.AstAPIKeyFlag) { + viper.Set(params.PreferredCredentialTypeKey, "apikey") + return "apikey" + } else { + viper.Set(params.PreferredCredentialTypeKey, "") + } + result := viper.GetString(params.PreferredCredentialTypeKey) + return result +} diff --git a/internal/commands/root_test.go b/internal/commands/root_test.go index 654594c5b..149cc4ba8 100644 --- a/internal/commands/root_test.go +++ b/internal/commands/root_test.go @@ -160,6 +160,7 @@ func TestCreateCommand_WithInvalidFlag_ShouldReturnExitCode1(t *testing.T) { func executeTestCommand(cmd *cobra.Command, args ...string) error { fmt.Println("Executing command with args ", args) + defer viper.Reset() cmd.SetArgs(args) cmd.SilenceUsage = true return cmd.Execute() diff --git a/internal/params/flags.go b/internal/params/flags.go index 13f43f4bd..16a911b06 100644 --- a/internal/params/flags.go +++ b/internal/params/flags.go @@ -55,6 +55,7 @@ const ( AccessKeyIDFlag = "client-id" AccessKeySecretFlag = "client-secret" AccessKeyIDFlagUsage = "The OAuth2 client ID" + PreferredCredentialTypeKey = "preferred_credential_type" AccessKeySecretFlagUsage = "The OAuth2 client secret" InsecureFlag = "insecure" InsecureFlagUsage = "Ignore TLS certificate validations" diff --git a/internal/wrappers/client.go b/internal/wrappers/client.go index bf20407f2..b3f5a0775 100644 --- a/internal/wrappers/client.go +++ b/internal/wrappers/client.go @@ -600,6 +600,7 @@ func configureClientCredentialsAndGetNewToken() (string, error) { accessKeySecret := viper.GetString(commonParams.AccessKeySecretConfigKey) astAPIKey := viper.GetString(commonParams.AstAPIKey) var accessToken string + credType := viper.GetString(commonParams.PreferredCredentialTypeKey) if accessKeyID == "" && astAPIKey == "" { return "", errors.Errorf(FailedToAuth, "access key ID") @@ -612,10 +613,18 @@ func configureClientCredentialsAndGetNewToken() (string, error) { return "", err } - if astAPIKey != "" { - accessToken, err = getNewToken(getAPIKeyPayload(astAPIKey), authURI) - } else { - accessToken, err = getNewToken(getCredentialsPayload(accessKeyID, accessKeySecret), authURI) + if astAPIKey != "" && credType == "apikey" { + accessToken, err = getNewToken( + getAPIKeyPayload(astAPIKey), authURI) + } else if accessKeyID != "" && accessKeySecret != "" && credType == "access_key" { + accessToken, err = getNewToken( + getCredentialsPayload(accessKeyID, accessKeySecret), authURI) + } else if astAPIKey != "" { + accessToken, err = getNewToken( + getAPIKeyPayload(astAPIKey), authURI) + } else if accessKeyID != "" && accessKeySecret != "" { + accessToken, err = getNewToken( + getCredentialsPayload(accessKeyID, accessKeySecret), authURI) } if err != nil { diff --git a/test/integration/configuration_test.go b/test/integration/configuration_test.go index d38c91faa..e1e3398da 100644 --- a/test/integration/configuration_test.go +++ b/test/integration/configuration_test.go @@ -73,6 +73,9 @@ func TestSetConfigProperty_EnvVarConfigFilePath(t *testing.T) { err, _ = executeCommand(t, "configure", "set", "--prop-name", "cx_client_id", "--prop-value", "example_client_id") assert.NilError(t, err) + defer func() { + executeCommand(t, "configure", "set", "--prop-name", "cx_client_id", "--prop-value", "") + }() } func TestLoadConfiguration_ConfigFilePathFlag(t *testing.T) { @@ -100,6 +103,9 @@ func TestSetConfigProperty_ConfigFilePathFlag(t *testing.T) { err, _ = executeCommand(t, "configure", "set", "--prop-name", "cx_client_id", "--prop-value", "example_client_id", "--config-file-path", filePath) assert.NilError(t, err) + defer func() { + executeCommand(t, "configure", "set", "--prop-name", "cx_client_id", "--prop-value", "") + }() } func TestLoadConfiguration_ConfigFilePathFlagFileWithoutPermission(t *testing.T) {