Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import nextflow.BuildInfo
* Paolo Di Tommaso <[email protected]>
*/
@Slf4j
@Deprecated
@CompileStatic
class SimpleHttpClient {

Expand Down
10 changes: 5 additions & 5 deletions modules/nf-commons/src/main/nextflow/util/RetryConfig.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ import nextflow.SysEnv
@CompileStatic
class RetryConfig implements Retryable.Config {

private final static Duration DEFAULT_DELAY = Duration.of('350ms')
private final static Duration DEFAULT_MAX_DELAY = Duration.of('90s')
private final static Integer DEFAULT_MAX_ATTEMPTS = 5
private final static Double DEFAULT_JITTER = 0.25
static final public double DEFAULT_MULTIPLIER = 2.0
public final static Duration DEFAULT_DELAY = Duration.of('350ms')
public final static Duration DEFAULT_MAX_DELAY = Duration.of('90s')
public final static Integer DEFAULT_MAX_ATTEMPTS = 5
public final static Double DEFAULT_JITTER = 0.25
public final static double DEFAULT_MULTIPLIER = 2.0

private final static String ENV_PREFIX = 'NXF_RETRY_POLICY_'

Expand Down
1 change: 1 addition & 0 deletions plugins/nf-tower/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies {
compileOnly 'org.slf4j:slf4j-api:2.0.17'
compileOnly 'org.pf4j:pf4j:3.12.0'

api 'io.seqera:lib-httpx:1.6.0'
api "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.15.0"
api "com.fasterxml.jackson.core:jackson-databind:2.12.7.1"

Expand Down
168 changes: 73 additions & 95 deletions plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package io.seqera.tower.plugin

import java.net.http.HttpClient
import java.net.http.HttpRequest

import java.time.Instant
import java.time.OffsetDateTime
Expand All @@ -29,14 +31,15 @@ import groovy.json.JsonGenerator
import groovy.json.JsonOutput
import groovy.json.JsonSlurper
import groovy.transform.CompileStatic
import groovy.transform.Memoized
import groovy.transform.ToString
import groovy.transform.TupleConstructor
import groovy.util.logging.Slf4j
import io.seqera.http.HxClient
import io.seqera.http.HxConfig
import io.seqera.util.trace.TraceUtils
import nextflow.BuildInfo
import nextflow.Session
import nextflow.container.resolver.ContainerMeta
import nextflow.container.resolver.ContainerResolver
import nextflow.container.resolver.ContainerResolverProvider
import nextflow.exception.AbortOperationException
import nextflow.processor.TaskHandler
import nextflow.processor.TaskId
Expand All @@ -49,7 +52,6 @@ import nextflow.trace.event.TaskEvent
import nextflow.util.Duration
import nextflow.util.LoggerHelper
import nextflow.util.ProcessHelper
import nextflow.util.SimpleHttpClient
import nextflow.util.TestOnly
import nextflow.util.Threads
/**
Expand Down Expand Up @@ -98,10 +100,7 @@ class TowerClient implements TraceObserverV2 {
*/
private String runId

/**
* Simple http client object that will send out messages
*/
private SimpleHttpClient httpClient
private HxClient httpClient

private JsonGenerator generator

Expand Down Expand Up @@ -139,19 +138,20 @@ class TowerClient implements TraceObserverV2 {

private String accessToken

private String refreshToken

private String workspaceId

private TowerReports reports

private TowerRetryPolicy retryPolicy

private Map<String,Boolean> allContainers = new ConcurrentHashMap<>()

TowerClient(Session session, TowerConfig config) {
this.session = session
this.endpoint = checkUrl(config.endpoint)
this.accessToken = config.accessToken
this.workspaceId = config.workspaceId
this.retryPolicy = config.retryPolicy
this.schema = loadSchema()
this.generator = TowerJsonGenerator.create(schema)
this.reports = new TowerReports(session)
Expand Down Expand Up @@ -278,9 +278,7 @@ class TowerClient implements TraceObserverV2 {
this.aggregator = new ResourcesAggregator(session)
this.runName = session.getRunName()
this.runId = session.getUniqueId()
this.httpClient = new SimpleHttpClient()
// set the auth token
setAuthToken( httpClient, getAccessToken() )
this.httpClient = newHttpClient()

// send hello to verify auth
final req = makeCreateReq(session)
Expand All @@ -305,10 +303,28 @@ class TowerClient implements TraceObserverV2 {
reports.flowCreate(workflowId)
}

protected void setAuthToken(SimpleHttpClient client, String token) {
protected HxClient newHttpClient() {
final config = new HxConfig.Builder()
// auth settings
setupClientAuth(config, getAccessToken())
// retry settings
config.withRetryConfig(this.retryPolicy)
// create the client object
final client = HttpClient
.newBuilder()
.followRedirects(HttpClient.Redirect.NORMAL)
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(java.time.Duration.ofSeconds(60))
.build()
return HxClient.create(client, config.build())
}

protected void setupClientAuth(HxConfig.Builder config, String token) {
// check for plain jwt token
if( token.count('.')==2 ) {
client.setBearerToken(token)
config.withJwtToken(token)
config.withRefreshToken(env.get('TOWER_REFRESH_TOKEN'))
config.withRefreshTokenUrl("$endpoint/oauth/access_token")
return
}

Expand All @@ -318,7 +334,10 @@ class TowerClient implements TraceObserverV2 {
final p = plain.indexOf('.')
if( p!=-1 && new JsonSlurper().parseText( plain.substring(0, p) ) ) {
// ok this is bearer token
client.setBearerToken(token)
config.withJwtToken(token)
// setup the refresh
config.withRefreshToken(env.get('TOWER_REFRESH_TOKEN'))
config.withRefreshTokenUrl("$endpoint/oauth/access_token")
return
}
}
Expand All @@ -327,7 +346,7 @@ class TowerClient implements TraceObserverV2 {
}

// fallback on simple token
client.setBasicToken(TOKEN_PREFIX + token)
config.withBasicAuth(TOKEN_PREFIX + token)
}

protected Map makeCreateReq(Session session) {
Expand All @@ -352,9 +371,6 @@ class TowerClient implements TraceObserverV2 {
@Override
void onFlowBegin() {
// configure error retry
httpClient.maxRetries = maxRetries
httpClient.backOffBase = backOffBase
httpClient.backOffDelay = backOffDelay

final req = makeBeginReq(session)
final resp = sendHttpMessage(urlTraceBegin, req, 'PUT')
Expand Down Expand Up @@ -479,37 +495,6 @@ class TowerClient implements TraceObserverV2 {
reports.filePublish(event.target)
}

protected void refreshToken(String refresh) {
log.debug "Token refresh request >> $refresh"
final url = "$endpoint/oauth/access_token"
httpClient.sendHttpMessage(
url,
method: 'POST',
contentType: "application/x-www-form-urlencoded",
body: "grant_type=refresh_token&refresh_token=${URLEncoder.encode(refresh, 'UTF-8')}" )

final authCookie = httpClient.getCookie('JWT')
final refreshCookie = httpClient.getCookie('JWT_REFRESH_TOKEN')

// set the new bearer token
if( authCookie?.value ) {
log.trace "Updating http client bearer token=$authCookie.value"
httpClient.setBearerToken(authCookie.value)
}
else {
log.warn "Missing JWT cookie from refresh token response ~ $authCookie"
}

// set the new refresh token
if( refreshCookie?.value ) {
log.trace "Updating http client refresh token=$refreshCookie.value"
refreshToken = refreshCookie.value
}
else {
log.warn "Missing JWT_REFRESH_TOKEN cookie from refresh token response ~ $refreshCookie"
}
}

/**
* Little helper method that sends a HTTP POST message as JSON with
* the current run status, ISO 8601 UTC timestamp, run name and the TraceRecord
Expand All @@ -520,51 +505,48 @@ class TowerClient implements TraceObserverV2 {
*/
protected Response sendHttpMessage(String url, Map payload, String method='POST') {

int refreshTries=0
final currentRefresh = refreshToken ?: env.get('TOWER_REFRESH_TOKEN')

while ( true ) {
// The actual HTTP request
final String json = payload != null ? generator.toJson(payload) : null
final String debug = json != null ? JsonOutput.prettyPrint(json).indent() : '-'
log.trace "HTTP url=$url; payload:\n${debug}\n"
try {
if( refreshTries==1 ) {
refreshToken(currentRefresh)
}

httpClient.sendHttpMessage(url, json, method)
return new Response(httpClient.responseCode, httpClient.getResponse())
// The actual HTTP request
final String json = payload != null ? generator.toJson(payload) : null
final String debug = json != null ? JsonOutput.prettyPrint(json).indent() : '-'
log.trace "HTTP url=$url; payload:\n${debug}\n"
try {
final resp = httpClient.sendAsString(makeRequest(url, json, method))
final status = resp.statusCode()
if( status == 401 ) {
final msg = 'Unauthorized Seqera Platform API access -- Make sure you have specified the correct access token'
return new Response(status, msg)
}
catch( ConnectException e ) {
String msg = "Unable to connect to Seqera Platform API: ${getHostUrl(url)}"
return new Response(0, msg)
}
catch (IOException e) {
int code = httpClient.responseCode
if( code == 401 && ++refreshTries==1 && currentRefresh ) {
// when 401 Unauthorized error is returned - only the very first time -
// and a refresh token is available, make another iteration trying
// having refreshed the authorization token (see 'refreshToken' invocation above)
log.trace "Got 401 Unauthorized response ~ tries refreshing auth token"
continue
}
else {
log.trace("Got HTTP code $code - refreshTries=$refreshTries - currentRefresh=$currentRefresh", e)
}

String msg
if( code == 401 ) {
msg = 'Unauthorized Seqera Platform API access -- Make sure you have specified the correct access token'
}
else {
msg = parseCause(httpClient.response) ?: "Unexpected response for request $url"
}
return new Response(code, msg, httpClient.response)
if( status>=400 ) {
final msg = parseCause(resp?.body()) ?: "Unexpected response for request $url"
return new Response(status, msg as String)
}
else
return new Response(status, resp.body())
}
catch( IOException e ) {
String msg = "Unable to connect to Seqera Platform API: ${getHostUrl(url)}"
return new Response(0, msg)
}
}

protected HttpRequest makeRequest(String url, String payload, String verb) {
assert payload, "Tower request cannot be empty"

final builder = HttpRequest.newBuilder(URI.create(url))
.header('Content-Type', 'application/json; charset=utf-8')
.header('User-Agent', "Nextflow/$BuildInfo.version")
.header('Traceparent', TraceUtils.rndTrace())

if(verb == 'PUT')
return builder.PUT(HttpRequest.BodyPublishers.ofString(payload)).build()

if(verb == 'POST')
return builder.POST(HttpRequest.BodyPublishers.ofString(payload)).build()

else
throw new IllegalArgumentException("Unsupported HTTP verb: $verb")
}

protected boolean isCliLogsEnabled() {
return env.get('TOWER_ALLOW_NEXTFLOW_LOGS') == 'true'
}
Expand Down Expand Up @@ -847,8 +829,4 @@ class TowerClient implements TraceObserverV2 {
}
}

@Memoized
private ContainerResolver containerResolver() {
ContainerResolverProvider.load()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class TowerConfig implements ConfigScope {
""")
final String workspaceId

final TowerRetryPolicy retryPolicy

/* required by extension point -- do not remove */
TowerConfig() {}

Expand All @@ -68,5 +70,6 @@ class TowerConfig implements ConfigScope {
this.enabled = opts.enabled as boolean
this.endpoint = PlatformHelper.getEndpoint(opts, env)
this.workspaceId = PlatformHelper.getWorkspaceId(opts, env)
this.retryPolicy = new TowerRetryPolicy(opts.retryPolicy as Map ?: Map.of(), opts)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ import nextflow.Session
import nextflow.SysEnv
import nextflow.file.http.XAuthProvider
import nextflow.file.http.XAuthRegistry
import nextflow.trace.TraceObserverV2
import nextflow.trace.TraceObserverFactoryV2
import nextflow.trace.TraceObserverV2
import nextflow.util.Duration
import nextflow.util.SimpleHttpClient
/**
* Create and register the Tower observer instance
*
Expand Down Expand Up @@ -70,10 +69,6 @@ class TowerFactory implements TraceObserverFactoryV2 {
tower.aliveInterval = aliveInterval
if( requestInterval )
tower.requestInterval = requestInterval
// error handling settings
tower.maxRetries = opts.maxRetries != null ? opts.maxRetries as int : 5
tower.backOffBase = opts.backOffBase != null ? opts.backOffBase as int : SimpleHttpClient.DEFAULT_BACK_OFF_BASE
tower.backOffDelay = opts.backOffDelay != null ? opts.backOffDelay as int : SimpleHttpClient.DEFAULT_BACK_OFF_DELAY

// register auth provider
// note: this is needed to authorize access to resources via XFileSystemProvider used by NF
Expand Down
Loading