diff --git a/modules/nextflow/src/main/groovy/nextflow/processor/TaskHandler.groovy b/modules/nextflow/src/main/groovy/nextflow/processor/TaskHandler.groovy
index a0f980fe7a..7babacf059 100644
--- a/modules/nextflow/src/main/groovy/nextflow/processor/TaskHandler.groovy
+++ b/modules/nextflow/src/main/groovy/nextflow/processor/TaskHandler.groovy
@@ -16,6 +16,10 @@
 
 package nextflow.processor
 
+import nextflow.Global
+import nextflow.Session
+import nextflow.fusion.FusionHelper
+
 import static nextflow.processor.TaskStatus.*
 
 import java.nio.file.NoSuchFileException
@@ -249,6 +253,17 @@ abstract class TaskHandler {
             catch( IOException e ) {
                 log.debug "[WARN] Cannot read trace file: $file -- Cause: ${e.message}"
             }
+            // If Fusion is enabled, parse the accelerator usage from the first line of .command.log
+            if( Global.session && FusionHelper.isFusionEnabled(Global.session as Session) ) {
+                final logFile = task.workDir?.resolve(TaskRun.CMD_LOG)
+                try {
+                    if( logFile ) record.parseFusionAcceleratorUsage(logFile)
+                }
+                catch( IOException e ) {
+                    // best-effort: a missing or unreadable log must not break trace record creation
+                    log.debug "[WARN] Cannot read Fusion log file: $logFile -- Cause: ${e.message}"
+                }
+            }
         }
 
         return record
diff --git a/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy b/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy
index 361eb53a57..ecc2865693 100644
--- a/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy
+++ b/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy
@@ -122,6 +122,7 @@ class TraceRecord implements Serializable {
     transient private CloudMachineInfo machineInfo
     transient private ContainerMeta containerMeta
     transient private Integer numSpotInterruptions
+    transient private Boolean acceleratorUsage
 
     /**
      * Convert the given value to a string
@@ -627,4 +628,52 @@ class TraceRecord implements Serializable {
     void setContainerMeta(ContainerMeta meta) {
         this.containerMeta = meta
     }
+
+    Boolean getAcceleratorUsage() {
+        return acceleratorUsage
+    }
+
+    void setAcceleratorUsage(Boolean acc) {
+        this.acceleratorUsage = acc
+    }
+
+    void parseFusionAcceleratorUsage(Path file) {
+        this.acceleratorUsage = parseFusionAccelerator0(file)
+    }
+
+    /**
+     * Parses the Fusion accelerator value.
+     * Fusion writes FUSION_GPU_USED=true|false in the first line of the log file.
+     *
+     * @param file the log file to inspect; may be {@code null} or point to a missing file
+     * @return the reported flag, or {@code null} when the file is absent or its first line
+     *      does not carry a valid {@code FUSION_GPU_USED} value
+     */
+    private Boolean parseFusionAccelerator0(Path file) {
+        if ( file == null || !file.exists() ) {
+            return null
+        }
+
+        String line = file.withReader { it.readLine() }
+
+        if (!line) {
+            return null
+        }
+
+        line = line.trim()
+
+        if (!line.startsWith("FUSION_GPU_USED=")) {
+            return null
+        }
+
+        String value = line.substring("FUSION_GPU_USED=".length()).trim()
+
+        if (!value.equalsIgnoreCase("true") && !value.equalsIgnoreCase("false")) {
+            return null
+        }
+
+        return value.toBoolean()
+    }
+
+
 }
diff --git a/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy b/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy
index f827747a42..a921f8974a 100644
--- a/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy
+++ b/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy
@@ -369,4 +369,89 @@ class TraceRecordTest extends Specification {
         rec2.getNumSpotInterruptions() == null
     }
 
+    def 'should manage accelerator field and not persist it across serialization'() {
+        given:
+        def rec = new TraceRecord()
+
+        expect:
+        rec.getAcceleratorUsage() == null
+        and:
+        rec.acceleratorUsage == null
+
+        when:
+        rec.setAcceleratorUsage(true)
+
+        then:
+        rec.getAcceleratorUsage() == true
+        rec.acceleratorUsage == true
+
+        when:
+        rec.setAcceleratorUsage(false)
+
+        then:
+        rec.getAcceleratorUsage() == false
+        rec.acceleratorUsage == false
+
+        when:
+        def buf = rec.serialize()
+        def rec2 = TraceRecord.deserialize(buf)
+
+        then:
+        rec2.getAcceleratorUsage() == null
+    }
+
+    @Unroll
+    def 'should parse fusion accelerator from file'() {
+        given:
+        def rec = new TraceRecord()
+        def file = TestHelper.createInMemTempFile('fusion-log')
+        file.text = CONTENT
+
+        when:
+        rec.parseFusionAcceleratorUsage(file)
+
+        then:
+        rec.getAcceleratorUsage() == EXPECTED
+
+        where:
+        CONTENT | EXPECTED
+        'FUSION_GPU_USED=true\n' | true
+        'FUSION_GPU_USED=false\n' | false
+        'FUSION_GPU_USED=TRUE\n' | true
+        'FUSION_GPU_USED=FALSE\n' | false
+        ' FUSION_GPU_USED=true \n' | true
+        ' FUSION_GPU_USED=false \n' | false
+        'FUSION_GPU_USED=true \nother line' | true
+        'FUSION_GPU_USED=false\nother line' | false
+        'other content\n' | null
+        'FUSION_GPU=true\n' | null
+        'FUSION_GPU_USED=\n' | null
+        'FUSION_GPU_USED=invalid\n' | null
+        'FUSION_GPU_USED=123\n' | null
+        '' | null
+    }
+
+    def 'should parse fusion accelerator when file does not exist'() {
+        given:
+        def rec = new TraceRecord()
+        def file = Path.of('/non/existent/file.log')
+
+        when:
+        rec.parseFusionAcceleratorUsage(file)
+
+        then:
+        rec.getAcceleratorUsage() == null
+    }
+
+    def 'should parse fusion accelerator when file is null'() {
+        given:
+        def rec = new TraceRecord()
+
+        when:
+        rec.parseFusionAcceleratorUsage(null)
+
+        then:
+        rec.getAcceleratorUsage() == null
+    }
+
 }
diff --git a/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy b/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy
index 4ffcc150d5..e64826d9b6 100644
--- a/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy
+++ b/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy
@@ -661,6 +661,7 @@ class TowerClient implements TraceObserverV2 {
         record.machineType = trace.getMachineInfo()?.type
         record.priceModel = trace.getMachineInfo()?.priceModel?.toString()
         record.numSpotInterruptions = trace.getNumSpotInterruptions()
+        record.acceleratorUsage = trace.getAcceleratorUsage()
 
         return record
     }
diff --git a/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy b/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy
index c133e3d897..6fb319a5b3 100644
--- a/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy
+++ b/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy
@@ -560,4 +560,29 @@ class TowerClientTest extends Specification {
         req.tasks[0].numSpotInterruptions == 3
     }
 
+    def 'should include accelerator in task map'() {
+        given:
+        def client = Spy(new TowerClient())
+        client.getWorkflowProgress(true) >> new WorkflowProgress()
+
+        def now = System.currentTimeMillis()
+        def trace = new TraceRecord([
+            taskId: 42,
+            process: 'foo',
+            workdir: "/work/dir",
+            cpus: 1,
+            submit: now-2000,
+            start: now-1000,
+            complete: now
+        ])
+        trace.setAcceleratorUsage(true)
+
+        when:
+        def req = client.makeTasksReq([trace])
+
+        then:
+        req.tasks.size() == 1
+        req.tasks[0].acceleratorUsage == true
+    }
+
 }