diff --git a/internal/commands/predicates.go b/internal/commands/predicates.go index 7b212c705..f0501291a 100644 --- a/internal/commands/predicates.go +++ b/internal/commands/predicates.go @@ -1,11 +1,16 @@ package commands import ( + "fmt" + "sort" "strings" "time" + "encoding/json" + "github.com/MakeNowJust/heredoc" "github.com/checkmarx/ast-cli/internal/commands/util/printer" + "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/pkg/errors" @@ -92,8 +97,8 @@ func triageShowSubCommand(resultsPredicatesWrapper wrappers.ResultsPredicatesWra triageShowCmd.PersistentFlags().String(params.SimilarityIDFlag, "", "Similarity ID") triageShowCmd.PersistentFlags().String(params.ProjectIDFlag, "", "Project ID") triageShowCmd.PersistentFlags().String(params.ScanTypeFlag, "", "Scan Type") + triageShowCmd.PersistentFlags().StringSlice(params.VulnerabilitiesFlag, []string{}, "SCA Vulnerabilities details") - markFlagAsRequired(triageShowCmd, params.SimilarityIDFlag) markFlagAsRequired(triageShowCmd, params.ProjectIDFlag) markFlagAsRequired(triageShowCmd, params.ScanTypeFlag) @@ -117,6 +122,7 @@ func triageUpdateSubCommand(resultsPredicatesWrapper wrappers.ResultsPredicatesW --scan-type `, ), + RunE: runTriageUpdate(resultsPredicatesWrapper, featureFlagsWrapper, customStatesWrapper), } @@ -127,9 +133,8 @@ func triageUpdateSubCommand(resultsPredicatesWrapper wrappers.ResultsPredicatesW triageUpdateCmd.PersistentFlags().Int(params.CustomStateIDFlag, -1, "Specify the ID of the states that you would like to apply to this result") triageUpdateCmd.PersistentFlags().String(params.CommentFlag, "", "Optional comment") triageUpdateCmd.PersistentFlags().String(params.ScanTypeFlag, "", "Scan Type") + triageUpdateCmd.PersistentFlags().StringSlice(params.VulnerabilitiesFlag, []string{}, "SCA Vulnerabilities details") - markFlagAsRequired(triageUpdateCmd, params.SimilarityIDFlag) - markFlagAsRequired(triageUpdateCmd, params.SeverityFlag) markFlagAsRequired(triageUpdateCmd, params.ProjectIDFlag) markFlagAsRequired(triageUpdateCmd, params.ScanTypeFlag) @@ -145,38 +150,48 @@ func runTriageShow(resultsPredicatesWrapper wrappers.ResultsPredicatesWrapper) f similarityID, _ := cmd.Flags().GetString(params.SimilarityIDFlag) scanType, _ := cmd.Flags().GetString(params.ScanTypeFlag) projectID, _ := cmd.Flags().GetString(params.ProjectIDFlag) - + vulnerabilityDetails, _ := cmd.Flags().GetStringSlice(params.VulnerabilitiesFlag) projectIDs := strings.Split(projectID, ",") if len(projectIDs) > 1 { return errors.Errorf("%s", "Multiple project-ids are not allowed.") } - predicatesCollection, errorModel, err = resultsPredicatesWrapper.GetAllPredicatesForSimilarityID( - similarityID, - projectID, - scanType, - ) - - if err != nil { - return errors.Wrapf(err, "%s", "Failed showing the predicate") - } - - // Checking the response - if errorModel != nil { - return errors.Errorf( - "%s: CODE: %d, %s", - "Failed showing the predicate.", - errorModel.Code, - errorModel.Message, - ) - } else if predicatesCollection != nil { - err = printByFormat(cmd, toPredicatesView(*predicatesCollection)) + if strings.EqualFold(strings.ToLower(strings.TrimSpace(scanType)), params.ScaType) { + if len(vulnerabilityDetails) == 0 { + return errors.Errorf("%s", "Failed showing the predicate. Vulnerabilities are required for SCA triage") + } + scaPredicates, err := resultsPredicatesWrapper.GetScaPredicates(vulnerabilityDetails, projectID) + if err != nil { + return errors.Wrapf(err, "%s", "Failed showing the predicate") + } + err = printByFormat(cmd, toScaPredicateResultView(scaPredicates)) if err != nil { return err } + return nil + } else { + predicatesCollection, errorModel, err = resultsPredicatesWrapper.GetAllPredicatesForSimilarityID(similarityID, projectID, scanType) + if err != nil { + return errors.Wrapf(err, "%s", "Failed showing the predicate") + } + // Checking the response + if errorModel != nil { + return errors.Errorf( + "%s: CODE: %d, %s", + "Failed showing the predicate.", + errorModel.Code, + errorModel.Message, + ) + } else if predicatesCollection != nil { + err = printByFormat(cmd, toPredicatesView(*predicatesCollection)) + if err != nil { + return err + } + } + + return nil } - return nil } } @@ -191,38 +206,124 @@ func runTriageUpdate(resultsPredicatesWrapper wrappers.ResultsPredicatesWrapper, scanType, _ := cmd.Flags().GetString(params.ScanTypeFlag) // check if the current tenant has critical severity available flagResponse, _ := wrappers.GetSpecificFeatureFlag(featureFlagsWrapper, wrappers.CVSSV3Enabled) + vulnerabilityDetails, _ := cmd.Flags().GetStringSlice(params.VulnerabilitiesFlag) + criticalEnabled := flagResponse.Status if !criticalEnabled && strings.EqualFold(severity, "critical") { return errors.Errorf("%s", "Critical severity is not available for your tenant.This severity status will be enabled shortly") } - var err error state, customStateID, err = determineSystemOrCustomState(customStatesWrapper, featureFlagsWrapper, state, customStateID) if err != nil { - return err + return errors.Wrapf(err, "%s", "Failed updating the predicate") } - predicate := &wrappers.PredicateRequest{ + predicate, err := preparePredicateRequest(vulnerabilityDetails, similarityID, projectID, severity, state, customStateID, comment, scanType) + if err != nil { + return errors.Wrapf(err, "%s", "Failed updating the predicate") + } + _, err = resultsPredicatesWrapper.PredicateSeverityAndState(predicate, scanType) + if err != nil { + return errors.Wrapf(err, "%s", "Failed updating the predicate") + } + return nil + } +} + +func preparePredicateRequest(vulnerabilityDetails []string, similarityID, projectID, severity, state string, customStateID int, comment, scanType string) (interface{}, error) { + scanType = strings.ToLower(scanType) + scanType = strings.TrimSpace(scanType) + if strings.EqualFold(scanType, Sca) { + state = transformState(state) + payload, err := prepareScaTriagePayload(vulnerabilityDetails, comment, state, projectID) + if err != nil { + return nil, err + } + return payload, nil + } else { + payload := &wrappers.PredicateRequest{ SimilarityID: similarityID, ProjectID: projectID, Severity: severity, Comment: comment, } - if state != "" { - predicate.State = &state + payload.State = &state } else { - predicate.CustomStateID = &customStateID + payload.CustomStateID = &customStateID } + return payload, nil + } +} - _, err = resultsPredicatesWrapper.PredicateSeverityAndState(predicate, scanType) +func transformState(state string) string { + state = strings.ToLower(strings.TrimSpace(state)) + switch state { + case strings.ToLower(params.ToVerify): + return wrappers.ToVerify + case strings.ToLower(params.URGENT): + return wrappers.Urgent + case strings.ToLower(params.NotExploitable): + return wrappers.NotExploitable + case strings.ToLower(params.ProposedNotExploitable): + return wrappers.ProposedNotExploitable + case strings.ToLower(params.CONFIRMED): + return wrappers.Confirmed + } + return "" +} + +func prepareScaTriagePayload(vulnerabilityDetails []string, comment, state, projectID string) (interface{}, error) { + if len(vulnerabilityDetails) == 0 { + return nil, errors.Errorf("Vulnerabilities details are required.") + } + scaTriageInfo := make(map[string]interface{}) + for _, vulnerability := range vulnerabilityDetails { + vulnerabilityKeyVal := strings.Split(vulnerability, "=") + err := validateVulnerabilityDetails(vulnerabilityKeyVal) if err != nil { - return errors.Wrapf(err, "%s", "Failed updating the predicate") + return nil, err } + scaTriageInfo[strings.TrimSpace(vulnerabilityKeyVal[0])] = strings.TrimSpace(vulnerabilityKeyVal[1]) + } - return nil + if scaTriageInfo["packageName"] == nil && scaTriageInfo["packagename"] == nil { + return nil, errors.Errorf("Package name is required") } + if scaTriageInfo["packageVersion"] == nil && scaTriageInfo["packageversion"] == nil { + return nil, errors.Errorf("Package version is required") + } + if scaTriageInfo["packageManager"] == nil && scaTriageInfo["packagemanager"] == nil { + return nil, errors.Errorf("Package manager is required") + } + + scaTriageInfo["projectIds"] = []string{projectID} + actionInfo := make(map[string]interface{}) + actionInfo["actionType"] = params.ChangeState + actionInfo["value"] = state + actionInfo["comment"] = comment + scaTriageInfo["actions"] = []map[string]interface{}{actionInfo} + b, err := json.Marshal(scaTriageInfo) + if err != nil { + logger.PrintIfVerbose(fmt.Sprintf("Failed to serialize vulnerabilities %s", scaTriageInfo)) + return nil, errors.Errorf("Failed to prepare SCA triage request") + } + payload := wrappers.ScaPredicateRequest{} + err = json.Unmarshal(b, &payload) + if err != nil { + logger.PrintIfVerbose(fmt.Sprintf("Failed to deserialize vulnerabilities %s", string(b))) + return nil, errors.Errorf("Failed to prepare SCA triage request") + } + return payload, nil +} + +func validateVulnerabilityDetails(vulnerability []string) error { + if len(vulnerability) != params.KeyValuePairSize { + return errors.Errorf("Invalid vulnerabilities. It should be in a KEY=VALUE format") + } + return nil } + func determineSystemOrCustomState(customStatesWrapper wrappers.CustomStatesWrapper, featureFlagsWrapper wrappers.FeatureFlagsWrapper, state string, customStateID int) (string, int, error) { if !isCustomState(state) { return state, -1, nil @@ -284,6 +385,41 @@ type predicateView struct { CreatedAt time.Time `format:"name:Created at;time:01-02-06 15:04:05"` } +type scaPredicateResultView struct { + VulnerabilityID string `format:"name:Vulnerability ID"` + PackageName string `format:"name:Package Name"` + PackageVersion string `format:"name:Package Version"` + PackageManager string `format:"name:Package Manager"` + Comment string `format:"name:Comment"` + State string `format:"name:State"` + CreatedBy string `format:"name:Created By"` + CreatedAt time.Time `format:"name:Created at;time:01-02-06 15:04:05"` +} + +func toScaPredicateResultView(scaPredicateResult *wrappers.ScaPredicateResult) []scaPredicateResultView { + view := []scaPredicateResultView{} + if len(scaPredicateResult.Actions) > 0 { + for _, action := range scaPredicateResult.Actions { + view = append(view, scaPredicateResultView{ + VulnerabilityID: scaPredicateResult.Context.VulnerabilityID, + PackageName: scaPredicateResult.Context.PackageName, + PackageVersion: scaPredicateResult.Context.PackageVersion, + PackageManager: scaPredicateResult.Context.PackageManager, + Comment: action.Message, + State: action.ActionValue, + CreatedBy: action.UserName, + CreatedAt: action.CreatedAt, + }) + } + } + + sort.Slice(view, func(i, j int) bool { + return view[i].CreatedAt.After(view[j].CreatedAt) + }) + + return view +} + func toPredicatesView(predicatesCollection wrappers.PredicatesCollectionResponseModel) []predicateView { projectPredicatesCollection := predicatesCollection.PredicateHistoryPerProject diff --git a/internal/commands/predicates_test.go b/internal/commands/predicates_test.go index 94890dcda..c7d49461c 100644 --- a/internal/commands/predicates_test.go +++ b/internal/commands/predicates_test.go @@ -5,6 +5,7 @@ package commands import ( "fmt" "testing" + "time" "github.com/checkmarx/ast-cli/internal/wrappers" "github.com/checkmarx/ast-cli/internal/wrappers/mock" @@ -12,6 +13,8 @@ import ( "gotest.tools/assert" ) +var requiredFlagsError = "required flag(s) \"project-id\", \"scan-type\" not set" + func TestTriageHelp(t *testing.T) { execCmdNilAssertion(t, "help", "triage") } @@ -41,7 +44,7 @@ func TestRunUpdateTriageCommand(t *testing.T) { func TestRunShowTriageCommandWithNoInput(t *testing.T) { err := execCmdNotNilAssertion(t, "triage", "show") - assert.Assert(t, err.Error() == "required flag(s) \"project-id\", \"scan-type\", \"similarity-id\" not set") + assert.Assert(t, err.Error() == requiredFlagsError) } func TestRunUpdateTriageCommandWithNoInput(t *testing.T) { @@ -49,7 +52,7 @@ func TestRunUpdateTriageCommandWithNoInput(t *testing.T) { fmt.Println(err) assert.Assert( t, - err.Error() == "required flag(s) \"project-id\", \"scan-type\", \"severity\", \"similarity-id\" not set") + err.Error() == requiredFlagsError) } func TestTriageGetStatesFlag(t *testing.T) { @@ -339,3 +342,383 @@ func TestDetermineSystemOrCustomState(t *testing.T) { }) } } + +func TestPrepareScaTriagePayload(t *testing.T) { + tests := []struct { + name string + vulnerabilityDetails []string + comment string + state string + projectID string + expectedError string + }{ + { + name: "Missing packageName", + vulnerabilityDetails: []string{ + "packageVersion=4.17.20", + "packageManager=npm", + "vulnerabilityId=CVE-2021-23337", + }, + comment: "Testing missing package name", + state: "NOT_EXPLOITABLE", + projectID: "test-project-123", + expectedError: "Package name is required", + }, + { + name: "Missing packageVersion", + vulnerabilityDetails: []string{ + "packageName=lodash", + "packageManager=npm", + "vulnerabilityId=CVE-2021-23337", + }, + comment: "Testing missing package version", + state: "NOT_EXPLOITABLE", + projectID: "test-project-123", + expectedError: "Package version is required", + }, + { + name: "Missing packageManager", + vulnerabilityDetails: []string{ + "packageName=lodash", + "packageVersion=4.17.20", + "vulnerabilityId=CVE-2021-23337", + }, + comment: "Testing missing package manager", + state: "NOT_EXPLOITABLE", + projectID: "test-project-123", + expectedError: "Package manager is required", + }, + { + name: "Invalid vulnerability format - no equals sign", + vulnerabilityDetails: []string{ + "packageNamelodash", + "packageVersion=4.17.20", + "packageManager=npm", + }, + comment: "Testing invalid format", + state: "NOT_EXPLOITABLE", + projectID: "test-project-123", + expectedError: "Invalid vulnerabilities. It should be in a KEY=VALUE format", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + payload, err := prepareScaTriagePayload(tt.vulnerabilityDetails, tt.comment, tt.state, tt.projectID) + if tt.expectedError != "" { + assert.ErrorContains(t, err, tt.expectedError) + } else { + assert.NilError(t, err) + assert.Assert(t, payload != nil, "Expected payload to be non-nil") + } + }) + } +} + +func TestPrepareScaTriagePayloadWithMissingVulnerabilities(t *testing.T) { + payload, err := prepareScaTriagePayload(nil, "Testing missing vulnerabilities", "NOT_EXPLOITABLE", "test-project-123") + assert.ErrorContains(t, err, "Vulnerabilities details are required.") + assert.Assert(t, payload == nil, "Expected payload to be nil") +} + +func TestRunShowTriageCommandForSCAWithMissingVulnerabilities(t *testing.T) { + err := execCmdNotNilAssertion( + t, + "triage", + "show", + "--project-id", + "MOCK", + "--scan-type", + "sca", + ) + // SCA triage show requires vulnerabilities flag + assert.Assert(t, err != nil, "Expected error when vulnerabilities flag is missing") +} + +func TestRunShowTriageCommandForSCAWithMultipleProjects(t *testing.T) { + err := execCmdNotNilAssertion( + t, + "triage", + "show", + "--project-id", + "MOCK1,MOCK2", + "--scan-type", + "sca", + "--vulnerability-identifiers", + "packageName=lodash,packageVersion=4.17.20,packageManager=npm", + ) + assert.ErrorContains(t, err, "Multiple project-ids are not allowed") +} + +func TestToScaPredicateResultView(t *testing.T) { + // Arrange: Create sample SCA predicate result + createdAt1, _ := time.Parse(time.RFC3339, "2024-01-15T10:00:00Z") + createdAt2, _ := time.Parse(time.RFC3339, "2024-01-16T12:00:00Z") + + scaPredicateResult := &wrappers.ScaPredicateResult{ + Context: wrappers.Context{ + VulnerabilityID: "CVE-2021-23337", + PackageName: "lodash", + PackageVersion: "4.17.20", + PackageManager: "npm", + }, + Actions: []wrappers.Action{ + { + ActionType: "ChangeState", + ActionValue: "NOT_EXPLOITABLE", + Message: "This is not exploitable in our context", + UserName: "test-user", + CreatedAt: createdAt1, + Enabled: true, + }, + { + ActionType: "ChangeState", + ActionValue: "CONFIRMED", + Message: "Actually, this needs to be fixed", + UserName: "test-user-2", + CreatedAt: createdAt2, + Enabled: true, + }, + }, + } + + // Act: Call the toScaPredicateResultView function + result := toScaPredicateResultView(scaPredicateResult) + + // Assert: Verify the conversion + assert.Equal(t, len(result), 2, "Expected 2 predicate result views") + + // Check first action + assert.Equal(t, result[1].VulnerabilityID, "CVE-2021-23337") + assert.Equal(t, result[1].PackageName, "lodash") + assert.Equal(t, result[1].PackageVersion, "4.17.20") + assert.Equal(t, result[1].PackageManager, "npm") + assert.Equal(t, result[1].State, "NOT_EXPLOITABLE") + assert.Equal(t, result[1].Comment, "This is not exploitable in our context") + assert.Equal(t, result[1].CreatedBy, "test-user") + assert.Equal(t, result[1].CreatedAt, createdAt1) + + // Check second action + assert.Equal(t, result[0].State, "CONFIRMED") + assert.Equal(t, result[0].Comment, "Actually, this needs to be fixed") + assert.Equal(t, result[0].CreatedBy, "test-user-2") +} + +func TestToScaPredicateResultView_EmptyActions(t *testing.T) { + // Arrange: Create SCA predicate result with no actions + scaPredicateResult := &wrappers.ScaPredicateResult{ + Context: wrappers.Context{ + VulnerabilityID: "CVE-2021-23337", + PackageName: "lodash", + PackageVersion: "4.17.20", + PackageManager: "npm", + }, + Actions: []wrappers.Action{}, + } + + // Act: Call the toScaPredicateResultView function + result := toScaPredicateResultView(scaPredicateResult) + + // Assert: Verify empty result + assert.Equal(t, len(result), 0, "Expected empty predicate result views") +} + +func TestTransformState(t *testing.T) { + tests := []struct { + name string + inputState string + expectedState string + }{ + { + name: "TO_VERIFY uppercase", + inputState: "TO_VERIFY", + expectedState: "ToVerify", + }, + { + name: "to_verify lowercase", + inputState: "to_verify", + expectedState: "ToVerify", + }, + { + name: "NOT_EXPLOITABLE uppercase", + inputState: "NOT_EXPLOITABLE", + expectedState: "NotExploitable", + }, + { + name: "not_exploitable lowercase", + inputState: "not_exploitable", + expectedState: "NotExploitable", + }, + { + name: "PROPOSED_NOT_EXPLOITABLE uppercase", + inputState: "PROPOSED_NOT_EXPLOITABLE", + expectedState: "ProposedNotExploitable", + }, + { + name: "CONFIRMED uppercase", + inputState: "CONFIRMED", + expectedState: "Confirmed", + }, + { + name: "URGENT uppercase", + inputState: "URGENT", + expectedState: "Urgent", + }, + { + name: "State with whitespace", + inputState: " TO_VERIFY ", + expectedState: "ToVerify", + }, + { + name: "Unknown state", + inputState: "CUSTOM_STATE", + expectedState: "", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + result := transformState(tt.inputState) + assert.Equal(t, result, tt.expectedState) + }) + } +} + +//nolint:goconst +func TestPrepareScaTriagePayloadWithValidData(t *testing.T) { + vulnerabilityDetails := []string{ + "packageName=lodash", + "packageVersion=4.17.20", + "packageManager=npm", + "vulnerabilityId=CVE-2021-23337", + } + comment := "This is a test comment" + state := "NOT_EXPLOITABLE" + projectID := "test-project-123" + + payload, err := prepareScaTriagePayload(vulnerabilityDetails, comment, state, projectID) + + assert.NilError(t, err) + assert.Assert(t, payload != nil, "Expected payload to be non-nil") +} + +func TestPrepareScaTriagePayloadWithCaseInsensitiveFields(t *testing.T) { + tests := []struct { + name string + vulnerabilityDetails []string + shouldSucceed bool + expectedError string + }{ + { + name: "Lowercase package fields", + vulnerabilityDetails: []string{ + "packagename=lodash", + "packageversion=4.17.20", + "packagemanager=npm", + }, + shouldSucceed: true, + }, + { + name: "Mixed case package fields", + vulnerabilityDetails: []string{ + "packageName=lodash", + "packageversion=4.17.20", + "packageManager=npm", + }, + shouldSucceed: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + payload, err := prepareScaTriagePayload(tt.vulnerabilityDetails, "test comment", "NOT_EXPLOITABLE", "project-123") + if tt.shouldSucceed { + assert.NilError(t, err) + assert.Assert(t, payload != nil, "Expected payload to be non-nil") + } else { + assert.ErrorContains(t, err, tt.expectedError) + } + }) + } +} + +func TestValidateVulnerabilityDetails(t *testing.T) { + tests := []struct { + name string + vulnerability []string + expectError bool + expectedError string + }{ + { + name: "Valid key-value pair", + vulnerability: []string{"packageName", "lodash"}, + expectError: false, + }, + { + name: "Invalid - no value", + vulnerability: []string{"packageName"}, + expectError: true, + expectedError: "Invalid vulnerabilities. It should be in a KEY=VALUE format", + }, + { + name: "Invalid - too many values", + vulnerability: []string{"packageName", "lodash", "extra"}, + expectError: true, + expectedError: "Invalid vulnerabilities. It should be in a KEY=VALUE format", + }, + { + name: "Invalid - empty array", + vulnerability: []string{}, + expectError: true, + expectedError: "Invalid vulnerabilities. It should be in a KEY=VALUE format", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + err := validateVulnerabilityDetails(tt.vulnerability) + if tt.expectError { + assert.ErrorContains(t, err, tt.expectedError) + } else { + assert.NilError(t, err) + } + }) + } +} + +func TestRunUpdateTriageCommandForSCA(t *testing.T) { + mockResultsPredicatesWrapper := &mock.ResultsPredicatesMockWrapper{} + mockFeatureFlagsWrapper := &mock.FeatureFlagsMockWrapper{} + mockCustomStatesWrapper := &mock.CustomStatesMockWrapper{} + + cmd := triageUpdateSubCommand(mockResultsPredicatesWrapper, mockFeatureFlagsWrapper, mockCustomStatesWrapper) + cmd.SetArgs([]string{ + "--project-id", "MOCK", + "--state", "NOT_EXPLOITABLE", + "--scan-type", "sca", + "--vulnerability-identifiers", "packageName=lodash,packageVersion=4.17.20,packageManager=npm", + "--comment", "Testing SCA update", + }) + + err := cmd.Execute() + assert.NilError(t, err) +} + +func TestPrepareScaTriagePayloadWithMissingVulnerabilityId(t *testing.T) { + vulnerabilityDetails := []string{ + "packageName=lodash", + "packageVersion=4.17.20", + "packageManager=npm", + } + comment := "Testing without vulnerability ID" + state := "NOT_EXPLOITABLE" + projectID := "test-project-123" + + payload, err := prepareScaTriagePayload(vulnerabilityDetails, comment, state, projectID) + assert.NilError(t, err) + assert.Assert(t, payload != nil, "Expected payload to be non-nil even without vulnerabilityId") +} diff --git a/internal/params/binds.go b/internal/params/binds.go index 4dd138e19..410de40d3 100644 --- a/internal/params/binds.go +++ b/internal/params/binds.go @@ -28,6 +28,7 @@ var EnvVarsBinds = []struct { {ScsScanOverviewPathKey, ScsScanOverviewPathEnv, "api/micro-engines/read/scans/%s/scan-overview"}, {SastResultsPathKey, SastResultsPathEnv, "api/sast-results"}, {SastResultsPredicatesPathKey, SastResultsPredicatesPathEnv, "api/sast-results-predicates"}, + {ScaResultsPredicatesPathKey, ScaResultsPredicatesPathEnv, "api/sca/management-of-risk/package-vulnerabilities"}, {KicsResultsPathKey, KicsResultsPathEnv, "api/kics-results"}, {KicsResultsPredicatesPathKey, KicsResultsPredicatesPathEnv, "api/kics-results-predicates"}, {ScsResultsReadPredicatesPathKey, ScsResultsReadPredicatesPathEnv, "api/micro-engines/read/predicates"}, diff --git a/internal/params/envs.go b/internal/params/envs.go index 9698eb699..56f484b70 100644 --- a/internal/params/envs.go +++ b/internal/params/envs.go @@ -31,6 +31,7 @@ const ( ScsScanOverviewPathEnv = "CX_SCS_SCAN_OVERVIEW_PATH" SastResultsPathEnv = "CX_SAST_RESULTS_PATH" SastResultsPredicatesPathEnv = "CX_SAST_RESULTS_PREDICATES_PATH" + ScaResultsPredicatesPathEnv = "CX_SCA_RESULTS_PREDICATES_PATH" KicsResultsPathEnv = "CX_KICS_RESULTS_PATH" KicsResultsPredicatesPathEnv = "CX_KICS_RESULTS_PREDICATES_PATH" ScsResultsReadPredicatesPathEnv = "CX_SCS_RESULTS_PREDICATES_READ_PATH" diff --git a/internal/params/flags.go b/internal/params/flags.go index 2eb507d52..b201078b6 100644 --- a/internal/params/flags.go +++ b/internal/params/flags.go @@ -128,7 +128,10 @@ const ( "Example: scan --threshold \"sast-high=10;sca-high=5;iac-security-low=10\"" KeyValuePairSize = 2 WaitDelayDefault = 5 + SingleValueSize = 1 + ChangeState = "ChangeState" SimilarityIDFlag = "similarity-id" + VulnerabilitiesFlag = "vulnerability-identifiers" SeverityFlag = "severity" StateFlag = "state" CustomStateIDFlag = "state-id" diff --git a/internal/params/keys.go b/internal/params/keys.go index adabc95a5..bcab73654 100644 --- a/internal/params/keys.go +++ b/internal/params/keys.go @@ -60,6 +60,7 @@ var ( LogsPathKey = strings.ToLower(LogsPathEnv) LogsEngineLogPathKey = strings.ToLower(LogsEngineLogPathEnv) SastResultsPredicatesPathKey = strings.ToLower(SastResultsPredicatesPathEnv) + ScaResultsPredicatesPathKey = strings.ToLower(ScaResultsPredicatesPathEnv) KicsResultsPredicatesPathKey = strings.ToLower(KicsResultsPredicatesPathEnv) ScsResultsReadPredicatesPathKey = strings.ToLower(ScsResultsReadPredicatesPathEnv) ScsResultsWritePredicatesPathKey = strings.ToLower(ScsResultsWritePredicatesPathEnv) diff --git a/internal/wrappers/mock/predicates-mock.go b/internal/wrappers/mock/predicates-mock.go index 894932ef7..9f8d7a513 100644 --- a/internal/wrappers/mock/predicates-mock.go +++ b/internal/wrappers/mock/predicates-mock.go @@ -10,7 +10,7 @@ import ( type ResultsPredicatesMockWrapper struct { } -func (r ResultsPredicatesMockWrapper) PredicateSeverityAndState(predicate *wrappers.PredicateRequest, scanType string) ( +func (r ResultsPredicatesMockWrapper) PredicateSeverityAndState(predicate interface{}, scanType string) ( *wrappers.WebError, error, ) { fmt.Println("Called 'PredicateSeverityAndState' in ResultsPredicatesMockWrapper") @@ -43,3 +43,8 @@ func (r ResultsPredicatesMockWrapper) GetAllPredicatesForSimilarityID(similarity }, }, nil, nil } + +func (r ResultsPredicatesMockWrapper) GetScaPredicates(vulnerabilityDetails []string, projectID string) (*wrappers.ScaPredicateResult, error) { + fmt.Println("Called 'GetScaPredicates' in ResultsPredicatesMockWrapper") + return nil, nil +} diff --git a/internal/wrappers/predicates-http.go b/internal/wrappers/predicates-http.go index 536856743..a6484ac64 100644 --- a/internal/wrappers/predicates-http.go +++ b/internal/wrappers/predicates-http.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "strings" @@ -26,6 +27,47 @@ func NewResultsPredicatesHTTPWrapper() ResultsPredicatesWrapper { return &ResultsPredicatesHTTPWrapper{} } +func (r *ResultsPredicatesHTTPWrapper) GetScaPredicates(vulnerabilityDetails []string, projectID string) (*ScaPredicateResult, error) { + clientTimeout := viper.GetUint(params.ClientTimeoutKey) + r.SetPath(viper.GetString(params.ScaResultsPredicatesPathEnv)) + var request = "/entity-profile/search" + logger.PrintIfVerbose(fmt.Sprintf("Sending POST request to %s", r.path+request)) + + scaPredicateRequest := make(map[string]interface{}) + for _, vulnerability := range vulnerabilityDetails { + vulnerabilityKeyVal := strings.Split(vulnerability, "=") + if len(vulnerabilityKeyVal) != params.KeyValuePairSize { + return nil, errors.Errorf("Invalid vulnerability details format: %s", vulnerability) + } + scaPredicateRequest[strings.TrimSpace(vulnerabilityKeyVal[0])] = strings.TrimSpace(vulnerabilityKeyVal[1]) + } + scaPredicateRequest["projectId"] = projectID + scaPredicateRequest["actionType"] = params.ChangeState + jsonBody, err := json.Marshal(scaPredicateRequest) + if err != nil { + return nil, errors.Wrap(err, "Failed to marshal request") + } + resp, err := SendHTTPRequestWithJSONContentType(http.MethodPost, r.path+request, bytes.NewBuffer(jsonBody), true, clientTimeout) + if err != nil { + return nil, errors.Wrap(err, "Failed to send request") + } + defer func() { + if err == nil { + _ = resp.Body.Close() + } + }() + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("Failed to get SCA predicate result.") + } + decoder := json.NewDecoder(resp.Body) + var scaPredicates ScaPredicateResult + err = decoder.Decode(&scaPredicates) + if err != nil { + return nil, errors.Wrap(err, "Failed to decode response") + } + return &scaPredicates, nil +} + func (r *ResultsPredicatesHTTPWrapper) GetAllPredicatesForSimilarityID(similarityID, projectID, scannerType string) ( *PredicatesCollectionResponseModel, *WebError, error, ) { @@ -65,25 +107,20 @@ func (r *ResultsPredicatesHTTPWrapper) SetPath(newPath string) { r.path = newPath } -func (r ResultsPredicatesHTTPWrapper) PredicateSeverityAndState(predicate *PredicateRequest, scanType string) ( +func (r ResultsPredicatesHTTPWrapper) PredicateSeverityAndState(predicate interface{}, scanType string) ( *WebError, error, ) { clientTimeout := viper.GetUint(params.ClientTimeoutKey) - b := [...]PredicateRequest{*predicate} - jsonBytes, err := json.Marshal(b) + + predicateModel := preparePredicateModel(predicate, scanType) + jsonBytes, err := json.Marshal(predicateModel) if err != nil { return nil, err } - var triageAPIPath string - if strings.EqualFold(strings.TrimSpace(scanType), params.SastType) { - triageAPIPath = viper.GetString(params.SastResultsPredicatesPathKey) - } else if strings.EqualFold(strings.TrimSpace(scanType), params.KicsType) || strings.EqualFold(strings.TrimSpace(scanType), params.IacType) { - triageAPIPath = viper.GetString(params.KicsResultsPredicatesPathKey) - } else if strings.EqualFold(strings.TrimSpace(scanType), params.ScsType) { - triageAPIPath = viper.GetString(params.ScsResultsWritePredicatesPathKey) - } else { - return nil, errors.Errorf(invalidScanType, scanType) + triageAPIPath, err := getTriageAPIPath(scanType) + if err != nil { + return nil, err } logger.PrintIfVerbose(fmt.Sprintf("Sending POST request to %s", triageAPIPath)) @@ -91,7 +128,7 @@ func (r ResultsPredicatesHTTPWrapper) PredicateSeverityAndState(predicate *Predi r.SetPath(triageAPIPath) - resp, err := SendHTTPRequest(http.MethodPost, r.path, bytes.NewBuffer(jsonBytes), true, clientTimeout) + resp, err := SendHTTPRequestWithJSONContentType(http.MethodPost, r.path, bytes.NewBuffer(jsonBytes), true, clientTimeout) if err != nil { return nil, err } @@ -102,36 +139,71 @@ func (r ResultsPredicatesHTTPWrapper) PredicateSeverityAndState(predicate *Predi _ = resp.Body.Close() }() - // in case of ne/pne when mandatory comment arent provided, cli is not transforming error message + if err := checkMandatoryCommentError(resp.Body, scanType); err != nil { + return nil, err + } + + return nil, handlePredicateStatusCode(resp.StatusCode) +} + +func preparePredicateModel(predicate interface{}, scanType string) interface{} { + if !strings.EqualFold(strings.TrimSpace(scanType), params.ScaType) { + return []interface{}{predicate} + } + return predicate +} + +func getTriageAPIPath(scanType string) (string, error) { + ScanType := strings.ToLower(strings.TrimSpace(scanType)) + + switch ScanType { + case strings.ToLower(params.SastType): + return viper.GetString(params.SastResultsPredicatesPathKey), nil + case strings.ToLower(params.KicsType), strings.ToLower(params.IacType): + return viper.GetString(params.KicsResultsPredicatesPathKey), nil + case strings.ToLower(params.ScsType): + return viper.GetString(params.ScsResultsWritePredicatesPathKey), nil + case strings.ToLower(params.ScaType): + return viper.GetString(params.ScaResultsPredicatesPathEnv), nil + default: + return "", errors.Errorf(invalidScanType, scanType) + } +} + +func checkMandatoryCommentError(body io.ReadCloser, scanType string) error { responseMap := make(map[string]interface{}) - if err := json.NewDecoder(resp.Body).Decode(&responseMap); err != nil { - logger.PrintIfVerbose(fmt.Sprintf("failed to read the response, %v", err.Error())) - } else { - if val, ok := responseMap["code"].(float64); ok { - if val == 4002 && responseMap["message"] != nil { - if errMsg, ok := responseMap["message"].(string); ok { - if errMsg == "A comment is required to make changes to the result state" { - return nil, errors.Errorf(errMsg) - } - } + if err := json.NewDecoder(body).Decode(&responseMap); err != nil { + if scanType != params.ScaType { + logger.PrintIfVerbose(fmt.Sprintf("failed to read the response, %v", err.Error())) + } + return nil + } + + if val, ok := responseMap["code"].(float64); ok && val == 4002 { + if errMsg, ok := responseMap["message"].(string); ok { + if errMsg == "A comment is required to make changes to the result state" { + return errors.Errorf(errMsg) } } } + return nil +} - switch resp.StatusCode { - case http.StatusBadRequest, http.StatusInternalServerError: - return nil, errors.Errorf("Predicate bad request.") +func handlePredicateStatusCode(statusCode int) error { + switch statusCode { case http.StatusOK, http.StatusCreated: fmt.Println("Predicate updated successfully.") - return nil, nil + return nil case http.StatusNotModified: - return nil, errors.Errorf("No changes to update.") + return errors.Errorf("No changes to update.") case http.StatusForbidden: - return nil, errors.Errorf("No permission to update predicate.") + return errors.Errorf("No permission to update predicate.") case http.StatusNotFound: - return nil, errors.Errorf("Predicate not found.") + return errors.Errorf("Predicate not found.") + case http.StatusBadRequest, http.StatusInternalServerError: + return errors.Errorf("Predicate bad request.") default: - return nil, errors.Errorf("response status code %d", resp.StatusCode) + return errors.Errorf("response status code %d", statusCode) } } diff --git a/internal/wrappers/predicates.go b/internal/wrappers/predicates.go index 709e5b3be..de719fa83 100644 --- a/internal/wrappers/predicates.go +++ b/internal/wrappers/predicates.go @@ -22,6 +22,31 @@ type PredicateRequest struct { Severity string `json:"severity"` } +type ScaPredicateRequest struct { + PackageName string `json:"packageName"` + PackageVersion string `json:"packageVersion"` + PackageManager string `json:"packageManager"` + VulnerabilityID string `json:"vulnerabilityId"` + ProjectIds []string `json:"projectIds"` + Actions []ScaAction `json:"actions"` +} + +type State string + +const ( + ToVerify string = "ToVerify" + Confirmed string = "Confirmed" + NotExploitable string = "NotExploitable" + ProposedNotExploitable string = "ProposedNotExploitable" + Urgent string = "Urgent" +) + +type ScaAction struct { + ActionType string `json:"actionType"` + Value string `json:"value"` + Comment string `json:"comment"` +} + type Predicate struct { BasePredicate ID string `json:"ID"` @@ -37,10 +62,37 @@ type PredicateHistory struct { } type PredicatesCollectionResponseModel struct { + ScaResponse interface{} `json:"scaPredicate,omitempty"` PredicateHistoryPerProject []PredicateHistory `json:"predicateHistoryPerProject"` TotalCount int `json:"totalCount"` } +type ScaPredicateResult struct { + ID string `json:"id"` + Context Context `json:"context"` + Name string `json:"name"` + Actions []Action `json:"actions"` + EntityType string `json:"entityType"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"createdAt"` +} + +type Context struct { + PackageManager string `json:"PackageManager"` + PackageName string `json:"PackageName"` + PackageVersion string `json:"PackageVersion"` + VulnerabilityID string `json:"VulnerabilityId"` +} + +type Action struct { + ActionType string `json:"actionType"` + ActionValue string `json:"actionValue"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"createdAt"` + UserName string `json:"userName"` + Message string `json:"message"` +} + type CustomState struct { ID int `json:"id"` Name string `json:"name"` @@ -52,7 +104,8 @@ type CustomStatesWrapper interface { } type ResultsPredicatesWrapper interface { - PredicateSeverityAndState(predicate *PredicateRequest, scanType string) (*WebError, error) + GetScaPredicates(vulnerabilityDetails []string, projectID string) (*ScaPredicateResult, error) + PredicateSeverityAndState(predicate interface{}, scanType string) (*WebError, error) GetAllPredicatesForSimilarityID( similarityID string, projectID string, scannerType string, ) (*PredicatesCollectionResponseModel, *WebError, error) diff --git a/test/integration/predicate_test.go b/test/integration/predicate_test.go index 5bd27d029..1e925e6ed 100644 --- a/test/integration/predicate_test.go +++ b/test/integration/predicate_test.go @@ -292,3 +292,50 @@ func TestTriageShowAndUpdateWithCustomStates(t *testing.T) { assert.Assert(t, found, "Updated predicate should have state set to state2") } + +func TestScaUpdateWithVulnerabilityDetails(t *testing.T) { + + fmt.Println("Testing the command 'triage update' with scan-type sca using vulnerability-details.") + + _, projectID := getRootScan(t) + + // Hardcoded vulnerability details for testing SCA triage + packageName := "Maven-org.apache.tomcat.embed:tomcat-embed-core" + packageVersion := "9.0.14" + vulnerabilityID := "CVE-2024-56337" + packageManager := "maven" + state := "NOT_EXPLOITABLE" + comment := "Testing CLI Command for triage with SCA scan type." + + args := []string{ + "triage", "update", + flag(params.ProjectIDFlag), projectID, + flag(params.VulnerabilitiesFlag), fmt.Sprintf("packagename=%s", packageName), + flag(params.VulnerabilitiesFlag), fmt.Sprintf("packageversion=%s", packageVersion), + flag(params.VulnerabilitiesFlag), fmt.Sprintf("vulnerabilityId=%s", vulnerabilityID), + flag(params.VulnerabilitiesFlag), fmt.Sprintf("packageManager=%s", packageManager), + flag(params.StateFlag), state, + flag(params.CommentFlag), comment, + flag(params.ScanTypeFlag), params.ScaType, + } + + err, outputBufferForStep1 := executeCommand(t, args...) + _, readingError := io.ReadAll(outputBufferForStep1) + assert.NilError(t, readingError, "Reading result should pass") + + assert.NilError(t, err, "Updating the SCA predicate with vulnerability-details should pass.") + + fmt.Println("Testing the command 'triage show' with scan-type sca to verify the update.") + outputBufferForStep2 := executeCmdNilAssertion( + t, "SCA Predicates should be fetched.", "triage", "show", + flag(params.FormatFlag), printer.FormatJSON, + flag(params.ProjectIDFlag), projectID, + flag(params.VulnerabilitiesFlag), fmt.Sprintf("packagename=%s", packageName), + flag(params.VulnerabilitiesFlag), fmt.Sprintf("packageversion=%s", packageVersion), + flag(params.VulnerabilitiesFlag), fmt.Sprintf("vulnerabilityId=%s", vulnerabilityID), + flag(params.VulnerabilitiesFlag), fmt.Sprintf("packageManager=%s", packageManager), + flag(params.ScanTypeFlag), params.ScaType, + ) + + fmt.Println(outputBufferForStep2) +}