Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(amazonq): for /test adding backoff and retry for payload upload APIs. #5310

Merged
merged 10 commits into from
Feb 3, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import software.amazon.awssdk.core.exception.SdkServiceException
import software.amazon.awssdk.services.codewhispererruntime.model.GetTestGenerationResponse
import software.amazon.awssdk.services.codewhispererruntime.model.Range
import software.amazon.awssdk.services.codewhispererruntime.model.StartTestGenerationResponse
import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
import software.amazon.awssdk.services.codewhispererruntime.model.TestGenerationJobStatus
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.aws.toolkits.core.utils.Waiters.waitUntil
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.error
import software.aws.toolkits.core.utils.getLogger
Expand Down Expand Up @@ -58,6 +58,7 @@ import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException
import java.nio.file.Paths
import java.time.Duration
import java.time.Instant
import java.util.concurrent.atomic.AtomicBoolean
import java.util.zip.ZipInputStream
Expand Down Expand Up @@ -109,29 +110,38 @@ class CodeWhispererUTGChatManager(val project: Project, private val cs: Coroutin

// 2nd API call: StartTestGeneration
val startTestGenerationResponse = try {
startTestGeneration(
uploadId = createUploadUrlResponse.uploadId(),
targetCode = listOf(
TargetCode.builder()
.relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString())
.targetLineRangeList(
if (selectionRange != null) {
listOf(
selectionRange
var response: StartTestGenerationResponse? = null

waitUntil(
succeedOn = { response?.sdkHttpResponse()?.statusCode() == 200 },
maxDuration = Duration.ofSeconds(1), // 1 second timeout
) {
try {
response = startTestGeneration(
uploadId = createUploadUrlResponse.uploadId(),
targetCode = listOf(
TargetCode.builder()
.relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString())
.targetLineRangeList(
if (selectionRange != null) {
listOf(selectionRange)
} else {
emptyList()
}
)
} else {
emptyList()
}
)
.build()
),
userInput = prompt
)
} catch (e: Exception) {
val statusCode = when {
e is SdkServiceException -> e.statusCode()
else -> 400
.build()
),
userInput = prompt
)
delay(200)
response?.testGenerationJob() != null
} catch (e: Exception) {
throw e
}
}

response ?: throw RuntimeException("Failed to start test generation")
} catch (e: Exception) {
LOG.error(e) { "Unexpected error while creating test generation job" }
val errorMessage = getTelemetryErrorMessage(e, CodeWhispererConstants.FeatureName.TEST_GENERATION)
throw CodeTestException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.CodeAnalysisUp
import software.amazon.awssdk.services.codewhispererruntime.model.CodeFixUploadContext
import software.amazon.awssdk.services.codewhispererruntime.model.CreateUploadUrlRequest
import software.amazon.awssdk.services.codewhispererruntime.model.CreateUploadUrlResponse
import software.amazon.awssdk.services.codewhispererruntime.model.InternalServerException
import software.amazon.awssdk.services.codewhispererruntime.model.ThrottlingException
import software.amazon.awssdk.services.codewhispererruntime.model.UploadContext
import software.amazon.awssdk.services.codewhispererruntime.model.UploadIntent
import software.amazon.awssdk.utils.IoUtils
Expand Down Expand Up @@ -82,40 +84,50 @@ class CodeWhispererZipUploadManager(private val project: Project) {
requestHeaders: Map<String, String>?,
featureUseCase: CodeWhispererConstants.FeatureName,
) {
try {
val uploadIdJson = """{"uploadId":"$uploadId"}"""
HttpRequests.put(url, "application/zip").userAgent(AwsClientManager.getUserAgent()).tuner {
if (requestHeaders.isNullOrEmpty()) {
it.setRequestProperty(CONTENT_MD5, md5)
it.setRequestProperty(CONTENT_TYPE, APPLICATION_ZIP)
it.setRequestProperty(SERVER_SIDE_ENCRYPTION, AWS_KMS)
if (kmsArn?.isNotEmpty() == true) {
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, kmsArn)
}
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_CONTEXT, Base64.getEncoder().encodeToString(uploadIdJson.toByteArray()))
} else {
requestHeaders.forEach { entry ->
it.setRequestProperty(entry.key, entry.value)
RetryableOperation<Unit>().execute(
operation = {
val uploadIdJson = """{"uploadId":"$uploadId"}"""
HttpRequests.put(url, "application/zip").userAgent(AwsClientManager.getUserAgent()).tuner {
if (requestHeaders.isNullOrEmpty()) {
it.setRequestProperty(CONTENT_MD5, md5)
it.setRequestProperty(CONTENT_TYPE, APPLICATION_ZIP)
it.setRequestProperty(SERVER_SIDE_ENCRYPTION, AWS_KMS)
if (kmsArn?.isNotEmpty() == true) {
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, kmsArn)
}
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_CONTEXT, Base64.getEncoder().encodeToString(uploadIdJson.toByteArray()))
} else {
requestHeaders.forEach { entry ->
it.setRequestProperty(entry.key, entry.value)
}
}
}.connect {
val connection = it.connection as HttpURLConnection
connection.setFixedLengthStreamingMode(fileToUpload.length())
IoUtils.copy(fileToUpload.inputStream(), connection.outputStream)
}
},
isRetryable = { e ->
when (e) {
is IOException -> true
else -> false
}
},
errorHandler = { e, attempts ->
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW ->
codeScanServerException("CreateUploadUrlException: $errorMessage")
CodeWhispererConstants.FeatureName.TEST_GENERATION ->
throw CodeTestException(
"UploadTestArtifactToS3Error: $errorMessage",
"UploadTestArtifactToS3Error",
message("testgen.error.generic_technical_error_message")
)
else -> throw RuntimeException("$errorMessage (after $attempts attempts)")
}
}.connect {
val connection = it.connection as HttpURLConnection
connection.setFixedLengthStreamingMode(fileToUpload.length())
IoUtils.copy(fileToUpload.inputStream(), connection.outputStream)
}
} catch (e: Exception) {
LOG.debug { "$featureUseCase: Artifact failed to upload in the S3 bucket: ${e.message}" }
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW -> codeScanServerException("CreateUploadUrlException: $errorMessage")
CodeWhispererConstants.FeatureName.TEST_GENERATION -> throw CodeTestException(
"UploadTestArtifactToS3Error: $errorMessage",
"UploadTestArtifactToS3Error",
message("testgen.error.generic_technical_error_message")
)
else -> throw RuntimeException(errorMessage) // Adding else for safety check
}
}
)
}

fun createUploadUrl(
Expand All @@ -124,35 +136,44 @@ class CodeWhispererZipUploadManager(private val project: Project) {
uploadTaskType: CodeWhispererConstants.UploadTaskType,
taskName: String,
featureUseCase: CodeWhispererConstants.FeatureName,
): CreateUploadUrlResponse = try {
CodeWhispererClientAdaptor.getInstance(project).createUploadUrl(
CreateUploadUrlRequest.builder()
.contentMd5(md5Content)
.artifactType(artifactType)
.uploadIntent(getUploadIntent(uploadTaskType))
.uploadContext(
// For UTG we don't need uploadContext but sending else case as UploadContext
if (uploadTaskType == CodeWhispererConstants.UploadTaskType.CODE_FIX) {
UploadContext.fromCodeFixUploadContext(CodeFixUploadContext.builder().codeFixName(taskName).build())
} else {
UploadContext.fromCodeAnalysisUploadContext(CodeAnalysisUploadContext.builder().codeScanName(taskName).build())
}
)
.build()
)
} catch (e: Exception) {
LOG.debug { "$featureUseCase: Create Upload URL failed: ${e.message}" }
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW -> codeScanServerException("CreateUploadUrlException: $errorMessage")
CodeWhispererConstants.FeatureName.TEST_GENERATION -> throw CodeTestException(
"CreateUploadUrlError: $errorMessage",
"CreateUploadUrlError",
message("testgen.error.generic_technical_error_message")
): CreateUploadUrlResponse = RetryableOperation<CreateUploadUrlResponse>().execute(
operation = {
CodeWhispererClientAdaptor.getInstance(project).createUploadUrl(
CreateUploadUrlRequest.builder()
.contentMd5(md5Content)
.artifactType(artifactType)
.uploadIntent(getUploadIntent(uploadTaskType))
.uploadContext(
// For UTG we don't need uploadContext but sending else case as UploadContext
if (uploadTaskType == CodeWhispererConstants.UploadTaskType.CODE_FIX) {
UploadContext.fromCodeFixUploadContext(CodeFixUploadContext.builder().codeFixName(taskName).build())
} else {
UploadContext.fromCodeAnalysisUploadContext(CodeAnalysisUploadContext.builder().codeScanName(taskName).build())
}
)
.build()
)
else -> throw RuntimeException(errorMessage) // Adding else for safety check
},
isRetryable = { e ->
e is ThrottlingException || e is InternalServerException
},
errorHandler = { e, attempts ->
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW ->
codeScanServerException("CreateUploadUrlException after $attempts attempts: $errorMessage")

CodeWhispererConstants.FeatureName.TEST_GENERATION ->
throw CodeTestException(
"CreateUploadUrlError after $attempts attempts: $errorMessage",
"CreateUploadUrlError",
message("testgen.error.generic_technical_error_message")
)

else -> throw RuntimeException("$errorMessage (after $attempts attempts)")
}
}
}
)

private fun getUploadIntent(uploadTaskType: CodeWhispererConstants.UploadTaskType): UploadIntent = when (uploadTaskType) {
CodeWhispererConstants.UploadTaskType.SCAN_FILE -> UploadIntent.AUTOMATIC_FILE_SECURITY_SCAN
Expand Down Expand Up @@ -187,3 +208,41 @@ fun getTelemetryErrorMessage(e: Exception, featureUseCase: CodeWhispererConstant
else -> message("testgen.message.failed")
}
}

class RetryableOperation<T> {
private var attempts = 0
private var currentDelay = INITIAL_DELAY
private var lastException: Exception? = null

fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean,
errorHandler: (Exception, Int) -> Nothing,
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
lastException = e

attempts++
if (attempts < MAX_RETRY_ATTEMPTS && isRetryable(e)) {
Thread.sleep(currentDelay)
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
continue
}

errorHandler(e, attempts)
}
}

// This line should never be reached due to errorHandler throwing exception
throw RuntimeException("Unexpected state after $attempts attempts")
}

companion object {
private const val INITIAL_DELAY = 100L // milliseconds
private const val MAX_BACKOFF = 10000L // milliseconds
private const val MAX_RETRY_ATTEMPTS = 3
}
}
Loading