diff --git a/src/main/java/com/iexec/worker/compute/app/AppComputeService.java b/src/main/java/com/iexec/worker/compute/app/AppComputeService.java index 8b0374442..2ecc05051 100644 --- a/src/main/java/com/iexec/worker/compute/app/AppComputeService.java +++ b/src/main/java/com/iexec/worker/compute/app/AppComputeService.java @@ -23,12 +23,10 @@ import com.iexec.commons.containers.DockerRunFinalStatus; import com.iexec.commons.containers.DockerRunRequest; import com.iexec.commons.containers.DockerRunResponse; -import com.iexec.commons.containers.SgxDriverMode; import com.iexec.commons.poco.task.TaskDescription; import com.iexec.worker.config.WorkerConfigurationService; import com.iexec.worker.docker.DockerService; import com.iexec.worker.metric.ComputeDurationsService; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.tee.TeeService; import com.iexec.worker.tee.TeeServicesManager; import com.iexec.worker.workflow.WorkflowError; @@ -44,19 +42,15 @@ public class AppComputeService { private final WorkerConfigurationService workerConfigService; private final DockerService dockerService; private final TeeServicesManager teeServicesManager; - private final SgxService sgxService; private final ComputeDurationsService appComputeDurationsService; - public AppComputeService( - WorkerConfigurationService workerConfigService, - DockerService dockerService, - TeeServicesManager teeServicesManager, - SgxService sgxService, - ComputeDurationsService appComputeDurationsService) { + public AppComputeService(final WorkerConfigurationService workerConfigService, + final DockerService dockerService, + final TeeServicesManager teeServicesManager, + final ComputeDurationsService appComputeDurationsService) { this.workerConfigService = workerConfigService; this.dockerService = dockerService; this.teeServicesManager = teeServicesManager; - this.sgxService = sgxService; this.appComputeDurationsService = appComputeDurationsService; } @@ -67,25 +61,22 @@ public AppComputeResponse runCompute(final TaskDescription taskDescription) { binds.add(Bind.parse(dockerService.getInputBind(chainTaskId))); binds.add(Bind.parse(dockerService.getIexecOutBind(chainTaskId))); - final SgxDriverMode sgxDriverMode; final List env; + final HostConfig hostConfig; if (taskDescription.requiresSgx()) { final TeeService teeService = teeServicesManager.getTeeService(taskDescription.getTeeFramework()); env = teeService.buildComputeDockerEnv(taskDescription); binds.addAll(teeService.getAdditionalBindings().stream().map(Bind::parse).toList()); - sgxDriverMode = sgxService.getSgxDriverMode(); + hostConfig = HostConfig.newHostConfig() + .withBinds(binds) + .withDevices(teeService.getDevices()) + .withNetworkMode(workerConfigService.getDockerNetworkName()); } else { env = IexecEnvUtils.getComputeStageEnvList(taskDescription); - sgxDriverMode = SgxDriverMode.NONE; + hostConfig = HostConfig.newHostConfig() + .withBinds(binds); } - final HostConfig hostConfig = HostConfig.newHostConfig() - .withBinds(binds) - .withDevices(sgxService.getSgxDevices()); - // Enclave should be able to connect to the LAS - if (taskDescription.requiresSgx()) { - hostConfig.withNetworkMode(workerConfigService.getDockerNetworkName()); - } final DockerRunRequest runRequest = DockerRunRequest.builder() .hostConfig(hostConfig) .chainTaskId(chainTaskId) @@ -94,7 +85,6 @@ public AppComputeResponse runCompute(final TaskDescription taskDescription) { .cmd(taskDescription.getDealParams().getIexecArgs()) .env(env) .maxExecutionTime(taskDescription.getMaxExecutionTime()) - .sgxDriverMode(sgxDriverMode) .build(); final DockerRunResponse dockerResponse = dockerService.run(runRequest); final Duration executionDuration = dockerResponse.getExecutionDuration(); diff --git a/src/main/java/com/iexec/worker/compute/post/PostComputeService.java b/src/main/java/com/iexec/worker/compute/post/PostComputeService.java index a8b8ab9ff..159178e30 100644 --- a/src/main/java/com/iexec/worker/compute/post/PostComputeService.java +++ b/src/main/java/com/iexec/worker/compute/post/PostComputeService.java @@ -26,7 +26,6 @@ import com.iexec.commons.containers.DockerRunRequest; import com.iexec.commons.containers.DockerRunResponse; import com.iexec.commons.poco.task.TaskDescription; -import com.iexec.sms.api.TeeSessionGenerationResponse; import com.iexec.sms.api.config.TeeAppProperties; import com.iexec.sms.api.config.TeeServicesProperties; import com.iexec.worker.compute.ComputeExitCauseService; @@ -34,7 +33,6 @@ import com.iexec.worker.config.WorkerConfigurationService; import com.iexec.worker.docker.DockerService; import com.iexec.worker.metric.ComputeDurationsService; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.tee.TeeService; import com.iexec.worker.tee.TeeServicesManager; import com.iexec.worker.tee.TeeServicesPropertiesService; @@ -47,7 +45,6 @@ import java.nio.file.attribute.BasicFileAttributes; import java.time.Duration; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; @@ -61,23 +58,19 @@ public class PostComputeService { private final WorkerConfigurationService workerConfigService; private final DockerService dockerService; private final TeeServicesManager teeServicesManager; - private final SgxService sgxService; private final ComputeExitCauseService computeExitCauseService; private final TeeServicesPropertiesService teeServicesPropertiesService; private final ComputeDurationsService postComputeDurationsService; - public PostComputeService( - WorkerConfigurationService workerConfigService, - DockerService dockerService, - TeeServicesManager teeServicesManager, - SgxService sgxService, - ComputeExitCauseService computeExitCauseService, - TeeServicesPropertiesService teeServicesPropertiesService, - ComputeDurationsService postComputeDurationsService) { + public PostComputeService(final WorkerConfigurationService workerConfigService, + final DockerService dockerService, + final TeeServicesManager teeServicesManager, + final ComputeExitCauseService computeExitCauseService, + final TeeServicesPropertiesService teeServicesPropertiesService, + final ComputeDurationsService postComputeDurationsService) { this.workerConfigService = workerConfigService; this.dockerService = dockerService; this.teeServicesManager = teeServicesManager; - this.sgxService = sgxService; this.computeExitCauseService = computeExitCauseService; this.teeServicesPropertiesService = teeServicesPropertiesService; this.postComputeDurationsService = postComputeDurationsService; @@ -162,13 +155,12 @@ public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) { } public PostComputeResponse runTeePostCompute(final TaskDescription taskDescription) { - String chainTaskId = taskDescription.getChainTaskId(); + final String chainTaskId = taskDescription.getChainTaskId(); - TeeServicesProperties properties = - teeServicesPropertiesService.getTeeServicesProperties(chainTaskId); + final TeeServicesProperties properties = teeServicesPropertiesService.getTeeServicesProperties(chainTaskId); final TeeAppProperties postComputeProperties = properties.getPostComputeProperties(); - String postComputeImage = postComputeProperties.getImage(); + final String postComputeImage = postComputeProperties.getImage(); if (!dockerService.getClient().isImagePresent(postComputeImage)) { log.error("Tee post-compute image not found locally [chainTaskId:{}]", chainTaskId); @@ -176,21 +168,20 @@ public PostComputeResponse runTeePostCompute(final TaskDescription taskDescripti .exitCauses(List.of(new WorkflowError(ReplicateStatusCause.POST_COMPUTE_IMAGE_MISSING))) .build(); } - TeeService teeService = teeServicesManager.getTeeService(taskDescription.getTeeFramework()); - List env = teeService - .buildPostComputeDockerEnv(taskDescription); - List binds = Stream.of( - Collections.singletonList(dockerService.getIexecOutBind(chainTaskId)), + final TeeService teeService = teeServicesManager.getTeeService(taskDescription.getTeeFramework()); + final List env = teeService.buildPostComputeDockerEnv(taskDescription); + final List binds = Stream.of( + List.of(dockerService.getIexecOutBind(chainTaskId)), teeService.getAdditionalBindings()) .flatMap(Collection::stream) .map(Bind::parse) .toList(); - HostConfig hostConfig = HostConfig.newHostConfig() + final HostConfig hostConfig = HostConfig.newHostConfig() .withBinds(binds) - .withDevices(sgxService.getSgxDevices()) + .withDevices(teeService.getDevices()) .withNetworkMode(workerConfigService.getDockerNetworkName()); - DockerRunRequest request = DockerRunRequest.builder() + final DockerRunRequest request = DockerRunRequest.builder() .hostConfig(hostConfig) .chainTaskId(chainTaskId) .containerName(getTaskTeePostComputeContainerName(chainTaskId)) @@ -198,9 +189,8 @@ public PostComputeResponse runTeePostCompute(final TaskDescription taskDescripti .entrypoint(postComputeProperties.getEntrypoint()) .maxExecutionTime(taskDescription.getMaxExecutionTime()) .env(env) - .sgxDriverMode(sgxService.getSgxDriverMode()) .build(); - DockerRunResponse dockerResponse = dockerService.run(request); + final DockerRunResponse dockerResponse = dockerService.run(request); final Duration executionDuration = dockerResponse.getExecutionDuration(); if (executionDuration != null) { postComputeDurationsService.addDurationForTask(chainTaskId, executionDuration.toMillis()); diff --git a/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java b/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java index 7af853032..b7aa47597 100644 --- a/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java +++ b/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java @@ -31,7 +31,7 @@ import com.iexec.worker.config.WorkerConfigurationService; import com.iexec.worker.docker.DockerService; import com.iexec.worker.metric.ComputeDurationsService; -import com.iexec.worker.sgx.SgxService; +import com.iexec.worker.tee.TeeService; import com.iexec.worker.tee.TeeServicesManager; import com.iexec.worker.tee.TeeServicesPropertiesService; import com.iexec.worker.workflow.WorkflowError; @@ -40,7 +40,6 @@ import org.springframework.util.unit.DataSize; import java.time.Duration; -import java.util.Collections; import java.util.List; import java.util.concurrent.TimeoutException; @@ -51,23 +50,19 @@ public class PreComputeService { private final DockerService dockerService; private final TeeServicesManager teeServicesManager; private final WorkerConfigurationService workerConfigService; - private final SgxService sgxService; private final ComputeExitCauseService computeExitCauseService; private final TeeServicesPropertiesService teeServicesPropertiesService; private final ComputeDurationsService preComputeDurationsService; - public PreComputeService( - DockerService dockerService, - TeeServicesManager teeServicesManager, - WorkerConfigurationService workerConfigService, - SgxService sgxService, - ComputeExitCauseService computeExitCauseService, - TeeServicesPropertiesService teeServicesPropertiesService, - ComputeDurationsService preComputeDurationsService) { + public PreComputeService(final DockerService dockerService, + final TeeServicesManager teeServicesManager, + final WorkerConfigurationService workerConfigService, + final ComputeExitCauseService computeExitCauseService, + final TeeServicesPropertiesService teeServicesPropertiesService, + final ComputeDurationsService preComputeDurationsService) { this.dockerService = dockerService; this.teeServicesManager = teeServicesManager; this.workerConfigService = workerConfigService; - this.sgxService = sgxService; this.computeExitCauseService = computeExitCauseService; this.teeServicesPropertiesService = teeServicesPropertiesService; this.preComputeDurationsService = preComputeDurationsService; @@ -159,28 +154,27 @@ private List getExitCauses(final String chainTaskId, final Intege * @return pre-compute exit code */ private Integer prepareTeeInputData(final TaskDescription taskDescription) throws TimeoutException { - String chainTaskId = taskDescription.getChainTaskId(); + final String chainTaskId = taskDescription.getChainTaskId(); log.info("Preparing tee input data [chainTaskId:{}]", chainTaskId); - TeeServicesProperties properties = - teeServicesPropertiesService.getTeeServicesProperties(chainTaskId); + final TeeServicesProperties properties = teeServicesPropertiesService.getTeeServicesProperties(chainTaskId); // check that docker image is present final TeeAppProperties preComputeProperties = properties.getPreComputeProperties(); - String preComputeImage = preComputeProperties.getImage(); + final String preComputeImage = preComputeProperties.getImage(); if (!dockerService.getClient().isImagePresent(preComputeImage)) { log.error("Tee pre-compute image not found locally [chainTaskId:{}]", chainTaskId); return null; } // run container - List env = teeServicesManager.getTeeService(taskDescription.getTeeFramework()) - .buildPreComputeDockerEnv(taskDescription); - List binds = Collections.singletonList(Bind.parse(dockerService.getInputBind(chainTaskId))); - HostConfig hostConfig = HostConfig.newHostConfig() + final TeeService teeService = teeServicesManager.getTeeService(taskDescription.getTeeFramework()); + final List env = teeService.buildPreComputeDockerEnv(taskDescription); + final List binds = List.of(Bind.parse(dockerService.getInputBind(chainTaskId))); + final HostConfig hostConfig = HostConfig.newHostConfig() .withBinds(binds) - .withDevices(sgxService.getSgxDevices()) + .withDevices(teeService.getDevices()) .withNetworkMode(workerConfigService.getDockerNetworkName()); - DockerRunRequest request = DockerRunRequest.builder() + final DockerRunRequest request = DockerRunRequest.builder() .hostConfig(hostConfig) .chainTaskId(chainTaskId) .containerName(getTeePreComputeContainerName(chainTaskId)) @@ -188,9 +182,8 @@ private Integer prepareTeeInputData(final TaskDescription taskDescription) throw .entrypoint(preComputeProperties.getEntrypoint()) .maxExecutionTime(taskDescription.getMaxExecutionTime()) .env(env) - .sgxDriverMode(sgxService.getSgxDriverMode()) .build(); - DockerRunResponse dockerResponse = dockerService.run(request); + final DockerRunResponse dockerResponse = dockerService.run(request); final Duration executionDuration = dockerResponse.getExecutionDuration(); if (executionDuration != null) { preComputeDurationsService.addDurationForTask(chainTaskId, executionDuration.toMillis()); diff --git a/src/main/java/com/iexec/worker/tee/TeeService.java b/src/main/java/com/iexec/worker/tee/TeeService.java index f733f6ce6..5e8400432 100644 --- a/src/main/java/com/iexec/worker/tee/TeeService.java +++ b/src/main/java/com/iexec/worker/tee/TeeService.java @@ -16,12 +16,12 @@ package com.iexec.worker.tee; +import com.github.dockerjava.api.model.Device; import com.iexec.commons.poco.chain.WorkerpoolAuthorization; import com.iexec.commons.poco.task.TaskDescription; import com.iexec.sms.api.SmsClientCreationException; import com.iexec.sms.api.TeeSessionGenerationError; import com.iexec.sms.api.TeeSessionGenerationResponse; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.sms.SmsService; import com.iexec.worker.sms.TeeSessionGenerationException; import com.iexec.worker.workflow.WorkflowError; @@ -36,24 +36,17 @@ @Slf4j public abstract class TeeService { - private final SgxService sgxService; private final SmsService smsService; protected final TeeServicesPropertiesService teeServicesPropertiesService; private final Map teeSessions = new ConcurrentHashMap<>(); - protected TeeService(final SgxService sgxService, - final SmsService smsService, + protected TeeService(final SmsService smsService, final TeeServicesPropertiesService teeServicesPropertiesService) { - this.sgxService = sgxService; this.smsService = smsService; this.teeServicesPropertiesService = teeServicesPropertiesService; } - public boolean isTeeEnabled() { - return sgxService.isSgxEnabled(); - } - public List areTeePrerequisitesMetForTask(final String chainTaskId) { if (!isTeeEnabled()) { return List.of(new WorkflowError(TEE_NOT_SUPPORTED)); @@ -98,6 +91,8 @@ public TeeSessionGenerationResponse getTeeSession(final String chainTaskId) { return teeSessions.get(chainTaskId); } + public abstract boolean isTeeEnabled(); + /** * Start any required service(s) to use TEE with selected technology for given task. * @@ -114,6 +109,8 @@ public TeeSessionGenerationResponse getTeeSession(final String chainTaskId) { public abstract Collection getAdditionalBindings(); + public abstract List getDevices(); + // region Purge /** diff --git a/src/main/java/com/iexec/worker/tee/gramine/TeeGramineService.java b/src/main/java/com/iexec/worker/tee/gramine/TeeGramineService.java index cbe1ec679..c9961e5fb 100644 --- a/src/main/java/com/iexec/worker/tee/gramine/TeeGramineService.java +++ b/src/main/java/com/iexec/worker/tee/gramine/TeeGramineService.java @@ -16,6 +16,7 @@ package com.iexec.worker.tee.gramine; +import com.github.dockerjava.api.model.Device; import com.iexec.common.lifecycle.purge.Purgeable; import com.iexec.commons.poco.task.TaskDescription; import com.iexec.sms.api.TeeSessionGenerationResponse; @@ -38,10 +39,18 @@ public class TeeGramineService extends TeeService implements Purgeable { private static final String SPS_SESSION_ENV_VAR = "session"; private static final String AESMD_SOCKET = "/var/run/aesmd/aesm.socket"; - public TeeGramineService(SgxService sgxService, - SmsService smsService, - TeeServicesPropertiesService teeServicesPropertiesService) { - super(sgxService, smsService, teeServicesPropertiesService); + private final SgxService sgxService; + + public TeeGramineService(final SgxService sgxService, + final SmsService smsService, + final TeeServicesPropertiesService teeServicesPropertiesService) { + super(smsService, teeServicesPropertiesService); + this.sgxService = sgxService; + } + + @Override + public boolean isTeeEnabled() { + return sgxService.isSgxEnabled(); } @Override @@ -72,6 +81,11 @@ public Collection getAdditionalBindings() { return bindings; } + @Override + public List getDevices() { + return sgxService.getSgxDevices(); + } + private List getDockerEnv(final TeeSessionGenerationResponse session) { return List.of( SPS_URL_ENV_VAR + "=" + session.getSecretProvisioningUrl(), diff --git a/src/main/java/com/iexec/worker/tee/scone/LasService.java b/src/main/java/com/iexec/worker/tee/scone/LasService.java index f231f842f..50765d34f 100644 --- a/src/main/java/com/iexec/worker/tee/scone/LasService.java +++ b/src/main/java/com/iexec/worker/tee/scone/LasService.java @@ -60,16 +60,13 @@ synchronized boolean start() { return true; } - HostConfig hostConfig = HostConfig.newHostConfig() + final HostConfig hostConfig = HostConfig.newHostConfig() .withDevices(sgxService.getSgxDevices()) .withNetworkMode(workerConfigService.getDockerNetworkName()); - DockerRunRequest dockerRunRequest = DockerRunRequest.builder() + final DockerRunRequest dockerRunRequest = DockerRunRequest.builder() .hostConfig(hostConfig) .containerName(containerName) .imageUri(imageUri) - // pre-compute, application & post-compute enclaves will be - // able to talk to the LAS via this network - .sgxDriverMode(sgxService.getSgxDriverMode()) .maxExecutionTime(0) .build(); if (!imageUri.contains(sconeConfig.getRegistry().getName())) { @@ -77,7 +74,7 @@ synchronized boolean start() { imageUri, sconeConfig.getRegistry().getName()); return false; } - DockerClientInstance client; + final DockerClientInstance client; try { client = dockerService.getClient( sconeConfig.getRegistry().getName(), diff --git a/src/main/java/com/iexec/worker/tee/scone/TeeSconeService.java b/src/main/java/com/iexec/worker/tee/scone/TeeSconeService.java index 7ef730016..735157b05 100644 --- a/src/main/java/com/iexec/worker/tee/scone/TeeSconeService.java +++ b/src/main/java/com/iexec/worker/tee/scone/TeeSconeService.java @@ -16,6 +16,7 @@ package com.iexec.worker.tee.scone; +import com.github.dockerjava.api.model.Device; import com.iexec.common.lifecycle.purge.Purgeable; import com.iexec.commons.poco.task.TaskDescription; import com.iexec.commons.poco.tee.TeeEnclaveConfiguration; @@ -48,14 +49,15 @@ public class TeeSconeService extends TeeService implements Purgeable { private static final String SCONE_LOG = "SCONE_LOG"; private static final String SCONE_VERSION = "SCONE_VERSION"; + private final SgxService sgxService; private final LasServicesManager lasServicesManager; - public TeeSconeService( - SgxService sgxService, - SmsService smsService, - TeeServicesPropertiesService teeServicesPropertiesService, - LasServicesManager lasServicesManager) { - super(sgxService, smsService, teeServicesPropertiesService); + public TeeSconeService(final SgxService sgxService, + final SmsService smsService, + final TeeServicesPropertiesService teeServicesPropertiesService, + final LasServicesManager lasServicesManager) { + super(smsService, teeServicesPropertiesService); + this.sgxService = sgxService; this.lasServicesManager = lasServicesManager; if (isTeeEnabled()) { @@ -65,6 +67,11 @@ public TeeSconeService( } } + @Override + public boolean isTeeEnabled() { + return sgxService.isSgxEnabled(); + } + @Override public List areTeePrerequisitesMetForTask(final String chainTaskId) { final List teePrerequisiteIssues = super.areTeePrerequisitesMetForTask(chainTaskId); @@ -76,7 +83,7 @@ public List areTeePrerequisitesMetForTask(final String chainTaskI } @Override - public boolean prepareTeeForTask(String chainTaskId) { + public boolean prepareTeeForTask(final String chainTaskId) { return lasServicesManager.startLasService(chainTaskId); } @@ -115,6 +122,11 @@ public Collection getAdditionalBindings() { return Collections.emptySet(); } + @Override + public List getDevices() { + return sgxService.getSgxDevices(); + } + private List getDockerEnv(String chainTaskId, String sconeConfigId, long sconeHeap, diff --git a/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java b/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java index 69ac33b91..b936b272d 100644 --- a/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java +++ b/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java @@ -24,7 +24,6 @@ import com.iexec.commons.containers.DockerRunFinalStatus; import com.iexec.commons.containers.DockerRunRequest; import com.iexec.commons.containers.DockerRunResponse; -import com.iexec.commons.containers.SgxDriverMode; import com.iexec.commons.poco.chain.DealParams; import com.iexec.commons.poco.order.OrderTag; import com.iexec.commons.poco.task.TaskDescription; @@ -34,10 +33,8 @@ import com.iexec.worker.config.WorkerConfigurationService; import com.iexec.worker.docker.DockerService; import com.iexec.worker.metric.ComputeDurationsService; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.tee.TeeService; import com.iexec.worker.tee.TeeServicesManager; -import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; @@ -48,6 +45,7 @@ import java.time.Duration; import java.util.List; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -91,8 +89,6 @@ TaskDescription.TaskDescriptionBuilder getTaskDescriptionBuilder(final OrderTag @Mock private TeeServicesManager teeServicesManager; @Mock - private SgxService sgxService; - @Mock private ComputeDurationsService appComputeDurationsService; @Mock @@ -114,7 +110,7 @@ void shouldRunCompute() { final AppComputeResponse appComputeResponse = appComputeService.runCompute(taskDescription); - Assertions.assertThat(appComputeResponse.isSuccessful()).isTrue(); + assertThat(appComputeResponse.isSuccessful()).isTrue(); verify(dockerService).run(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DockerRunRequest.class); @@ -122,9 +118,8 @@ void shouldRunCompute() { DockerRunRequest dockerRunRequest = argumentCaptor.getAllValues().get(0); HostConfig hostConfig = HostConfig.newHostConfig() - .withBinds(Bind.parse(inputBind), Bind.parse(iexecOutBind)) - .withDevices(List.of()); - Assertions.assertThat(dockerRunRequest).isEqualTo( + .withBinds(Bind.parse(inputBind), Bind.parse(iexecOutBind)); + assertThat(dockerRunRequest).isEqualTo( DockerRunRequest.builder() .hostConfig(hostConfig) .chainTaskId(CHAIN_TASK_ID) @@ -132,7 +127,6 @@ void shouldRunCompute() { .imageUri(APP_URI) .maxExecutionTime(MAX_EXECUTION_TIME) .env(IexecEnvUtils.getComputeStageEnvList(taskDescription)) - .sgxDriverMode(SgxDriverMode.NONE) .build() ); } @@ -159,13 +153,12 @@ void shouldRunComputeWithTeeAndConnectAppToLas() { .executionDuration(Duration.ofSeconds(10)) .build(); when(dockerService.run(any())).thenReturn(expectedDockerRunResponse); - when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.LEGACY); List devices = List.of(Device.parse("/dev/isgx")); - when(sgxService.getSgxDevices()).thenReturn(devices); + when(teeMockedService.getDevices()).thenReturn(devices); AppComputeResponse appComputeResponse = appComputeService.runCompute(taskDescription); - Assertions.assertThat(appComputeResponse.isSuccessful()).isTrue(); + assertThat(appComputeResponse.isSuccessful()).isTrue(); verify(dockerService).run(any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DockerRunRequest.class); @@ -176,7 +169,7 @@ void shouldRunComputeWithTeeAndConnectAppToLas() { .withBinds(Bind.parse(inputBind), Bind.parse(iexecOutBind)) .withDevices(devices) .withNetworkMode(lasNetworkName); - Assertions.assertThat(dockerRunRequest).isEqualTo( + assertThat(dockerRunRequest).isEqualTo( DockerRunRequest.builder() .hostConfig(hostConfig) .chainTaskId(CHAIN_TASK_ID) @@ -184,7 +177,6 @@ void shouldRunComputeWithTeeAndConnectAppToLas() { .imageUri(APP_URI) .maxExecutionTime(MAX_EXECUTION_TIME) .env(env) - .sgxDriverMode(SgxDriverMode.LEGACY) .build() ); } @@ -201,7 +193,7 @@ void shouldRunComputeWithFailDockerResponse() { AppComputeResponse appComputeResponse = appComputeService.runCompute(taskDescription); - Assertions.assertThat(appComputeResponse.isSuccessful()).isFalse(); + assertThat(appComputeResponse.isSuccessful()).isFalse(); verify(dockerService).run(any()); } diff --git a/src/test/java/com/iexec/worker/compute/post/PostComputeServiceTests.java b/src/test/java/com/iexec/worker/compute/post/PostComputeServiceTests.java index 67ea2671f..d4ca0b707 100644 --- a/src/test/java/com/iexec/worker/compute/post/PostComputeServiceTests.java +++ b/src/test/java/com/iexec/worker/compute/post/PostComputeServiceTests.java @@ -25,7 +25,6 @@ import com.iexec.commons.containers.DockerRunFinalStatus; import com.iexec.commons.containers.DockerRunRequest; import com.iexec.commons.containers.DockerRunResponse; -import com.iexec.commons.containers.SgxDriverMode; import com.iexec.commons.containers.client.DockerClientInstance; import com.iexec.commons.poco.task.TaskDescription; import com.iexec.sms.api.config.TeeAppProperties; @@ -35,7 +34,6 @@ import com.iexec.worker.config.WorkerConfigurationService; import com.iexec.worker.docker.DockerService; import com.iexec.worker.metric.ComputeDurationsService; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.tee.TeeService; import com.iexec.worker.tee.TeeServicesManager; import com.iexec.worker.tee.TeeServicesPropertiesService; @@ -107,8 +105,6 @@ class PostComputeServiceTests { @Mock private DockerClientInstance dockerClientInstanceMock; @Mock - private SgxService sgxService; - @Mock private ComputeExitCauseService computeExitCauseService; @Mock private TeeServicesPropertiesService teeServicesPropertiesService; @@ -215,7 +211,6 @@ void prepareMocksForTeePostCompute(DockerRunResponse dockerRunResponse) { when(dockerService.getIexecOutBind(CHAIN_TASK_ID)).thenReturn(iexecOutBind); when(workerConfigService.getWorkerName()).thenReturn(WORKER_NAME); when(workerConfigService.getDockerNetworkName()).thenReturn("lasNetworkName"); - when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.LEGACY); when(dockerService.run(any())).thenReturn(dockerRunResponse); } @@ -234,7 +229,7 @@ void shouldRunTeePostComputeAndConnectToLasNetwork() { .build(); prepareMocksForTeePostCompute(expectedDockerRunResponse); List devices = List.of(Device.parse("/dev/isgx")); - when(sgxService.getSgxDevices()).thenReturn(devices); + when(teeMockedService.getDevices()).thenReturn(devices); PostComputeResponse postComputeResponse = postComputeService.runTeePostCompute(taskDescription); @@ -261,7 +256,6 @@ void shouldRunTeePostComputeAndConnectToLasNetwork() { .entrypoint(TEE_POST_COMPUTE_ENTRYPOINT) .maxExecutionTime(MAX_EXECUTION_TIME) .env(env) - .sgxDriverMode(SgxDriverMode.LEGACY) .build() ); } diff --git a/src/test/java/com/iexec/worker/compute/pre/PreComputeServiceTests.java b/src/test/java/com/iexec/worker/compute/pre/PreComputeServiceTests.java index d79c1b9a9..68fa73d14 100644 --- a/src/test/java/com/iexec/worker/compute/pre/PreComputeServiceTests.java +++ b/src/test/java/com/iexec/worker/compute/pre/PreComputeServiceTests.java @@ -20,7 +20,6 @@ import com.iexec.commons.containers.DockerRunFinalStatus; import com.iexec.commons.containers.DockerRunRequest; import com.iexec.commons.containers.DockerRunResponse; -import com.iexec.commons.containers.SgxDriverMode; import com.iexec.commons.containers.client.DockerClientInstance; import com.iexec.commons.poco.chain.DealParams; import com.iexec.commons.poco.task.TaskDescription; @@ -34,7 +33,6 @@ import com.iexec.worker.config.WorkerConfigurationService; import com.iexec.worker.docker.DockerService; import com.iexec.worker.metric.ComputeDurationsService; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.tee.TeeService; import com.iexec.worker.tee.TeeServicesManager; import com.iexec.worker.tee.TeeServicesPropertiesService; @@ -99,8 +97,6 @@ class PreComputeServiceTests { @Mock private DockerClientInstance dockerClientInstanceMock; @Mock - private SgxService sgxService; - @Mock private ComputeExitCauseService computeExitCauseService; @Mock private TeeServicesPropertiesService teeServicesPropertiesService; @@ -136,7 +132,6 @@ void prepareMocksForPreCompute(final TaskDescription taskDescription, DockerRunR .thenReturn(List.of("env")); when(dockerService.getInputBind(chainTaskId)).thenReturn(IEXEC_IN_BIND); when(workerConfigService.getDockerNetworkName()).thenReturn(network); - when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.LEGACY); when(dockerService.run(any())).thenReturn(dockerRunResponse); } @@ -145,7 +140,6 @@ void verifyDockerRun() { DockerRunRequest capturedRequest = captor.getValue(); assertThat(capturedRequest.getImageUri()).isEqualTo(PRE_COMPUTE_IMAGE); assertThat(capturedRequest.getEntrypoint()).isEqualTo(PRE_COMPUTE_ENTRYPOINT); - assertThat(capturedRequest.getSgxDriverMode()).isEqualTo(SgxDriverMode.LEGACY); assertThat(capturedRequest.getHostConfig().getNetworkMode()).isEqualTo(network); assertThat(capturedRequest.getHostConfig().getBinds()[0]).hasToString(IEXEC_IN_BIND + ":rw"); } diff --git a/src/test/java/com/iexec/worker/tee/TeeServiceMock.java b/src/test/java/com/iexec/worker/tee/TeeServiceMock.java index c2fb01276..7f7315501 100644 --- a/src/test/java/com/iexec/worker/tee/TeeServiceMock.java +++ b/src/test/java/com/iexec/worker/tee/TeeServiceMock.java @@ -16,8 +16,8 @@ package com.iexec.worker.tee; +import com.github.dockerjava.api.model.Device; import com.iexec.commons.poco.task.TaskDescription; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.sms.SmsService; import java.util.Collection; @@ -25,10 +25,14 @@ class TeeServiceMock extends TeeService { - protected TeeServiceMock(SgxService sgxService, - SmsService smsService, + protected TeeServiceMock(SmsService smsService, TeeServicesPropertiesService teeServicesPropertiesService) { - super(sgxService, smsService, teeServicesPropertiesService); + super(smsService, teeServicesPropertiesService); + } + + @Override + public boolean isTeeEnabled() { + return true; } @Override @@ -53,6 +57,11 @@ public List buildPostComputeDockerEnv(TaskDescription taskDescription) { @Override public Collection getAdditionalBindings() { - return null; + return List.of(); + } + + @Override + public List getDevices() { + return List.of(); } } diff --git a/src/test/java/com/iexec/worker/tee/TeeServiceTests.java b/src/test/java/com/iexec/worker/tee/TeeServiceTests.java index 123554021..4310a5f47 100644 --- a/src/test/java/com/iexec/worker/tee/TeeServiceTests.java +++ b/src/test/java/com/iexec/worker/tee/TeeServiceTests.java @@ -17,19 +17,14 @@ package com.iexec.worker.tee; import com.iexec.commons.poco.chain.WorkerpoolAuthorization; -import com.iexec.sms.api.SmsClient; -import com.iexec.sms.api.SmsClientCreationException; import com.iexec.sms.api.TeeSessionGenerationError; import com.iexec.sms.api.TeeSessionGenerationResponse; -import com.iexec.worker.sgx.SgxService; import com.iexec.worker.sms.SmsService; import com.iexec.worker.sms.TeeSessionGenerationException; -import com.iexec.worker.workflow.WorkflowError; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; import org.mockito.Mock; -import org.mockito.Spy; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.boot.test.system.CapturedOutput; import org.springframework.boot.test.system.OutputCaptureExtension; @@ -38,12 +33,10 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import static com.iexec.common.replicate.ReplicateStatusCause.*; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @ExtendWith(OutputCaptureExtension.class) @@ -53,88 +46,12 @@ class TeeServiceTests { .chainTaskId(CHAIN_TASK_ID) .build(); - @Mock - SgxService sgxService; @Mock SmsService smsService; - @Mock - SmsClient smsClient; - @Mock - TeeServicesPropertiesService teeServicesPropertiesService; - @Spy @InjectMocks TeeServiceMock teeService; - // region isTeeEnabled - @Test - void shouldTeeBeEnabled() { - when(sgxService.isSgxEnabled()).thenReturn(true); - - assertTrue(teeService.isTeeEnabled()); - - verify(sgxService).isSgxEnabled(); - } - - @Test - void shouldTeeNotBeEnabled() { - when(sgxService.isSgxEnabled()).thenReturn(false); - - assertFalse(teeService.isTeeEnabled()); - - verify(sgxService).isSgxEnabled(); - } - // endregion - - // region areTeePrerequisitesMetForTask - @Test - void shouldTeePrerequisitesBeMet() { - when(teeService.isTeeEnabled()).thenReturn(true); - when(smsService.getSmsClient(CHAIN_TASK_ID)).thenReturn(smsClient); - when(teeServicesPropertiesService.getTeeServicesProperties(CHAIN_TASK_ID)).thenReturn(null); - - assertThat(teeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID)) - .isEmpty(); - } - - @Test - void shouldTeePrerequisitesNotBeMetSinceTeeNotEnabled() { - when(teeService.isTeeEnabled()).thenReturn(false); - - assertThat(teeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID)) - .containsExactly(new WorkflowError(TEE_NOT_SUPPORTED)); - } - - @Test - void shouldTeePrerequisitesNotBeMetSinceSmsClientCantBeLoaded() { - when(teeService.isTeeEnabled()).thenReturn(true); - when(smsService.getSmsClient(CHAIN_TASK_ID)).thenThrow(SmsClientCreationException.class); - - assertThat(teeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID)) - .containsExactly(new WorkflowError(UNKNOWN_SMS)); - } - - @Test - void shouldTeePrerequisitesNotBeMetSinceTeeEnclaveConfigurationIsNull() { - when(teeService.isTeeEnabled()).thenReturn(true); - when(smsService.getSmsClient(CHAIN_TASK_ID)).thenReturn(smsClient); - when(teeServicesPropertiesService.getTeeServicesProperties(CHAIN_TASK_ID)).thenThrow(NullPointerException.class); - - assertThat(teeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID)) - .containsExactly(new WorkflowError(PRE_COMPUTE_MISSING_ENCLAVE_CONFIGURATION)); - } - - @Test - void shouldTeePrerequisitesNotBeMetSinceTeeWorkflowConfigurationCantBeLoaded() { - when(teeService.isTeeEnabled()).thenReturn(true); - when(smsService.getSmsClient(CHAIN_TASK_ID)).thenReturn(smsClient); - when(teeServicesPropertiesService.getTeeServicesProperties(CHAIN_TASK_ID)).thenThrow(RuntimeException.class); - - assertThat(teeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID)) - .containsExactly(new WorkflowError(GET_TEE_SERVICES_CONFIGURATION_FAILED)); - } - // endregion - // region TEE sessions cache @Test void shouldAddTeeSessionGenerationResponseToCache() throws TeeSessionGenerationException { diff --git a/src/test/java/com/iexec/worker/tee/gramine/TeeGramineServiceTests.java b/src/test/java/com/iexec/worker/tee/gramine/TeeGramineServiceTests.java index 4b2309a7c..bf858e11e 100644 --- a/src/test/java/com/iexec/worker/tee/gramine/TeeGramineServiceTests.java +++ b/src/test/java/com/iexec/worker/tee/gramine/TeeGramineServiceTests.java @@ -37,9 +37,8 @@ import java.util.concurrent.ConcurrentHashMap; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class TeeGramineServiceTests { @@ -60,13 +59,26 @@ class TeeGramineServiceTests { @InjectMocks TeeGramineService teeGramineService; + // region isTeeEnabled + @Test + void shouldTeeBeEnabled() { + when(sgxService.isSgxEnabled()).thenReturn(true); + assertThat(teeGramineService.isTeeEnabled()).isTrue(); + } + + @Test + void shouldTeeNotBeEnabled() { + when(sgxService.isSgxEnabled()).thenReturn(false); + assertThat(teeGramineService.isTeeEnabled()).isFalse(); + } + // endregion + // region prepareTeeForTask @ParameterizedTest @NullSource @ValueSource(strings = {"", "0x123", "chainTaskId"}) void shouldPrepareTeeForTask(String chainTaskId) { - assertTrue(teeGramineService.prepareTeeForTask(chainTaskId)); - + assertThat(teeGramineService.prepareTeeForTask(chainTaskId)).isTrue(); verifyNoInteractions(sgxService, smsClientProvider, teeServicesPropertiesService); } // endregion @@ -78,12 +90,7 @@ void shouldBuildPreComputeDockerEnv(String chainTaskId) { ReflectionTestUtils.setField(teeGramineService, "teeSessions", Map.of(chainTaskId, TEE_SESSION_GENERATION_RESPONSE)); final TaskDescription taskDescription = TaskDescription.builder().chainTaskId(chainTaskId).build(); final List env = teeGramineService.buildPreComputeDockerEnv(taskDescription); - - assertEquals(2, env.size()); - assertTrue(env.containsAll(List.of( - "sps=http://spsUrl", - "session=0x123_session_id" - ))); + assertThat(env).containsExactly("sps=http://spsUrl", "session=0x123_session_id"); } // endregion @@ -94,12 +101,7 @@ void shouldBuildComputeDockerEnv(String chainTaskId) { ReflectionTestUtils.setField(teeGramineService, "teeSessions", Map.of(chainTaskId, TEE_SESSION_GENERATION_RESPONSE)); final TaskDescription taskDescription = TaskDescription.builder().chainTaskId(chainTaskId).build(); final List env = teeGramineService.buildComputeDockerEnv(taskDescription); - - assertEquals(2, env.size()); - assertTrue(env.containsAll(List.of( - "sps=http://spsUrl", - "session=0x123_session_id" - ))); + assertThat(env).containsExactly("sps=http://spsUrl", "session=0x123_session_id"); } // endregion @@ -110,12 +112,7 @@ void shouldBuildPostComputeDockerEnv(String chainTaskId) { ReflectionTestUtils.setField(teeGramineService, "teeSessions", Map.of(chainTaskId, TEE_SESSION_GENERATION_RESPONSE)); final TaskDescription taskDescription = TaskDescription.builder().chainTaskId(chainTaskId).build(); final List env = teeGramineService.buildPostComputeDockerEnv(taskDescription); - - assertEquals(2, env.size()); - assertTrue(env.containsAll(List.of( - "sps=http://spsUrl", - "session=0x123_session_id" - ))); + assertThat(env).containsExactly("sps=http://spsUrl", "session=0x123_session_id"); } // endregion @@ -123,9 +120,7 @@ void shouldBuildPostComputeDockerEnv(String chainTaskId) { @Test void shouldGetAdditionalBindings() { final Collection bindings = teeGramineService.getAdditionalBindings(); - - assertEquals(1, bindings.size()); - assertTrue(bindings.contains("/var/run/aesmd/aesm.socket:/var/run/aesmd/aesm.socket")); + assertThat(bindings).containsExactly("/var/run/aesmd/aesm.socket:/var/run/aesmd/aesm.socket"); } // endregion diff --git a/src/test/java/com/iexec/worker/tee/scone/LasServiceTests.java b/src/test/java/com/iexec/worker/tee/scone/LasServiceTests.java index f0cadeaf0..2b32bfbd9 100644 --- a/src/test/java/com/iexec/worker/tee/scone/LasServiceTests.java +++ b/src/test/java/com/iexec/worker/tee/scone/LasServiceTests.java @@ -26,7 +26,6 @@ import com.iexec.worker.config.WorkerConfigurationService; import com.iexec.worker.docker.DockerService; import com.iexec.worker.sgx.SgxService; -import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -38,6 +37,7 @@ import java.util.Arrays; import java.util.List; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -80,7 +80,6 @@ void init() { } private void createLasServiceStubs() { - when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.NATIVE); when(sconeConfiguration.getRegistry()) .thenReturn(new SconeConfiguration.SconeRegistry(REGISTRY_NAME, REGISTRY_USERNAME, REGISTRY_PASSWORD)); when(dockerService.getClient(REGISTRY_NAME, REGISTRY_USERNAME, REGISTRY_PASSWORD)) @@ -100,12 +99,11 @@ void shouldStartLasService() { assertTrue(lasService.start()); verify(dockerService).run(dockerRunRequestArgumentCaptor.capture()); DockerRunRequest dockerRunRequest = dockerRunRequestArgumentCaptor.getValue(); - Assertions.assertThat(dockerRunRequest).isEqualTo( + assertThat(dockerRunRequest).isEqualTo( DockerRunRequest.builder() .hostConfig(HostConfig.newHostConfig().withDevices(devices)) .containerName(CONTAINER_NAME) .imageUri(IMAGE_URI) - .sgxDriverMode(SgxDriverMode.NATIVE) .maxExecutionTime(0) .build() ); @@ -128,7 +126,6 @@ void shouldStartLasServiceOnlyOnce() { @Test void shouldNotStartLasServiceSinceUnknownRegistry() { - when(sgxService.getSgxDriverMode()).thenReturn(SgxDriverMode.NATIVE); when(sconeConfiguration.getRegistry()) .thenReturn(new SconeConfiguration.SconeRegistry(REGISTRY_NAME, REGISTRY_USERNAME, REGISTRY_PASSWORD)); diff --git a/src/test/java/com/iexec/worker/tee/scone/TeeSconeServiceTests.java b/src/test/java/com/iexec/worker/tee/scone/TeeSconeServiceTests.java index c05f67fea..677d50318 100644 --- a/src/test/java/com/iexec/worker/tee/scone/TeeSconeServiceTests.java +++ b/src/test/java/com/iexec/worker/tee/scone/TeeSconeServiceTests.java @@ -25,13 +25,13 @@ import com.iexec.sms.api.config.TeeAppProperties; import com.iexec.worker.sgx.SgxService; import com.iexec.worker.sms.SmsService; +import com.iexec.worker.tee.TeeServicesPropertiesCreationException; import com.iexec.worker.tee.TeeServicesPropertiesService; import com.iexec.worker.workflow.WorkflowError; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; import org.mockito.Mock; -import org.mockito.Spy; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.test.util.ReflectionTestUtils; @@ -60,7 +60,6 @@ class TeeSconeServiceTests { public static final long HEAP_SIZE = 1024L; @InjectMocks - @Spy private TeeSconeService teeSconeService; @Mock private SconeConfiguration sconeConfig; @@ -81,28 +80,42 @@ class TeeSconeServiceTests { @Mock private SmsClient smsClient; + // region isTeeEnabled + @Test + void shouldTeeBeEnabled() { + when(sgxService.isSgxEnabled()).thenReturn(true); + assertThat(teeSconeService.isTeeEnabled()).isTrue(); + } + + @Test + void shouldTeeNotBeEnabled() { + when(sgxService.isSgxEnabled()).thenReturn(false); + assertThat(teeSconeService.isTeeEnabled()).isFalse(); + } + // endregion + // region areTeePrerequisitesMetForTask @Test void shouldTeePrerequisiteMetForTask() { - doReturn(true).when(teeSconeService).isTeeEnabled(); - doReturn(smsClient).when(smsService).getSmsClient(CHAIN_TASK_ID); - doReturn(null).when(teeServicesPropertiesService).getTeeServicesProperties(CHAIN_TASK_ID); - doReturn(true).when(teeSconeService).prepareTeeForTask(CHAIN_TASK_ID); + when(sgxService.isSgxEnabled()).thenReturn(true); + when(smsService.getSmsClient(CHAIN_TASK_ID)).thenReturn(smsClient); + when(teeServicesPropertiesService.getTeeServicesProperties(CHAIN_TASK_ID)).thenReturn(null); + when(lasServicesManager.startLasService(CHAIN_TASK_ID)).thenReturn(true); final List teePrerequisitesIssue = teeSconeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID); assertThat(teePrerequisitesIssue).isEmpty(); - verify(teeSconeService).isTeeEnabled(); + verify(sgxService, times(2)).isSgxEnabled(); verify(smsService).getSmsClient(CHAIN_TASK_ID); verify(teeServicesPropertiesService).getTeeServicesProperties(CHAIN_TASK_ID); - verify(teeSconeService).prepareTeeForTask(CHAIN_TASK_ID); + verify(lasServicesManager).startLasService(CHAIN_TASK_ID); } @Test void shouldTeePrerequisiteNotMetForTaskSinceTeeNotEnabled() { - doReturn(false).when(teeSconeService).isTeeEnabled(); + when(sgxService.isSgxEnabled()).thenReturn(false); final List teePrerequisitesIssue = teeSconeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID); @@ -110,16 +123,14 @@ void shouldTeePrerequisiteNotMetForTaskSinceTeeNotEnabled() { assertThat(teePrerequisitesIssue) .containsExactly(new WorkflowError(TEE_NOT_SUPPORTED)); - verify(teeSconeService, times(1)).isTeeEnabled(); - verify(smsService, times(0)).getSmsClient(CHAIN_TASK_ID); - verify(teeServicesPropertiesService, times(0)).getTeeServicesProperties(CHAIN_TASK_ID); - verify(teeSconeService, times(0)).prepareTeeForTask(CHAIN_TASK_ID); + verify(sgxService, times(2)).isSgxEnabled(); + verifyNoInteractions(smsService, teeServicesPropertiesService, lasServicesManager); } @Test void shouldTeePrerequisiteNotMetForTaskSinceSmsClientCantBeLoaded() { - doReturn(true).when(teeSconeService).isTeeEnabled(); - doThrow(SmsClientCreationException.class).when(smsService).getSmsClient(CHAIN_TASK_ID); + when(sgxService.isSgxEnabled()).thenReturn(true); + when(smsService.getSmsClient(CHAIN_TASK_ID)).thenThrow(SmsClientCreationException.class); final List teePrerequisitesIssue = teeSconeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID); @@ -127,17 +138,17 @@ void shouldTeePrerequisiteNotMetForTaskSinceSmsClientCantBeLoaded() { assertThat(teePrerequisitesIssue) .containsExactly(new WorkflowError(UNKNOWN_SMS)); - verify(teeSconeService, times(1)).isTeeEnabled(); - verify(smsService, times(1)).getSmsClient(CHAIN_TASK_ID); - verify(teeServicesPropertiesService, times(0)).getTeeServicesProperties(CHAIN_TASK_ID); - verify(teeSconeService, times(0)).prepareTeeForTask(CHAIN_TASK_ID); + verify(sgxService, times(2)).isSgxEnabled(); + verify(smsService).getSmsClient(CHAIN_TASK_ID); + verifyNoInteractions(teeServicesPropertiesService, lasServicesManager); } @Test void shouldTeePrerequisiteNotMetForTaskSinceTeeWorkflowConfigurationCantBeLoaded() { - doReturn(true).when(teeSconeService).isTeeEnabled(); - doReturn(smsClient).when(smsService).getSmsClient(CHAIN_TASK_ID); - doThrow(SmsClientCreationException.class).when(teeServicesPropertiesService).getTeeServicesProperties(CHAIN_TASK_ID); + when(sgxService.isSgxEnabled()).thenReturn(true); + when(smsService.getSmsClient(CHAIN_TASK_ID)).thenReturn(smsClient); + when(teeServicesPropertiesService.getTeeServicesProperties(CHAIN_TASK_ID)) + .thenThrow(TeeServicesPropertiesCreationException.class); final List teePrerequisitesIssue = teeSconeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID); @@ -145,18 +156,35 @@ void shouldTeePrerequisiteNotMetForTaskSinceTeeWorkflowConfigurationCantBeLoaded assertThat(teePrerequisitesIssue) .containsExactly(new WorkflowError(GET_TEE_SERVICES_CONFIGURATION_FAILED)); - verify(teeSconeService, times(1)).isTeeEnabled(); - verify(smsService, times(1)).getSmsClient(CHAIN_TASK_ID); - verify(teeServicesPropertiesService, times(1)).getTeeServicesProperties(CHAIN_TASK_ID); - verify(teeSconeService, times(0)).prepareTeeForTask(CHAIN_TASK_ID); + verify(sgxService, times(2)).isSgxEnabled(); + verify(smsService).getSmsClient(CHAIN_TASK_ID); + verify(teeServicesPropertiesService).getTeeServicesProperties(CHAIN_TASK_ID); + verifyNoInteractions(lasServicesManager); + } + + @Test + void shouldTeePrerequisitesNotBeMetSinceTeeEnclaveConfigurationIsNull() { + when(sgxService.isSgxEnabled()).thenReturn(true); + when(smsService.getSmsClient(CHAIN_TASK_ID)).thenReturn(smsClient); + when(teeServicesPropertiesService.getTeeServicesProperties(CHAIN_TASK_ID)) + .thenThrow(NullPointerException.class); + + final List teePrerequisitesIssue = + teeSconeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID); + assertThat(teePrerequisitesIssue) + .containsExactly(new WorkflowError(PRE_COMPUTE_MISSING_ENCLAVE_CONFIGURATION)); + verify(sgxService, times(2)).isSgxEnabled(); + verify(smsService).getSmsClient(CHAIN_TASK_ID); + verify(teeServicesPropertiesService).getTeeServicesProperties(CHAIN_TASK_ID); + verifyNoInteractions(lasServicesManager); } @Test void shouldTeePrerequisiteNotMetForTaskSinceCantPrepareTee() { - doReturn(true).when(teeSconeService).isTeeEnabled(); - doReturn(smsClient).when(smsService).getSmsClient(CHAIN_TASK_ID); - doReturn(null).when(teeServicesPropertiesService).getTeeServicesProperties(CHAIN_TASK_ID); - doReturn(false).when(teeSconeService).prepareTeeForTask(CHAIN_TASK_ID); + when(sgxService.isSgxEnabled()).thenReturn(true); + when(smsService.getSmsClient(CHAIN_TASK_ID)).thenReturn(smsClient); + when(teeServicesPropertiesService.getTeeServicesProperties(CHAIN_TASK_ID)).thenReturn(null); + when(lasServicesManager.startLasService(CHAIN_TASK_ID)).thenReturn(false); final List teePrerequisitesIssue = teeSconeService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID); @@ -164,10 +192,10 @@ void shouldTeePrerequisiteNotMetForTaskSinceCantPrepareTee() { assertThat(teePrerequisitesIssue) .containsExactly(new WorkflowError(TEE_PREPARATION_FAILED)); - verify(teeSconeService, times(1)).isTeeEnabled(); - verify(smsService, times(1)).getSmsClient(CHAIN_TASK_ID); - verify(teeServicesPropertiesService, times(1)).getTeeServicesProperties(CHAIN_TASK_ID); - verify(teeSconeService, times(1)).prepareTeeForTask(CHAIN_TASK_ID); + verify(sgxService, times(2)).isSgxEnabled(); + verify(smsService).getSmsClient(CHAIN_TASK_ID); + verify(teeServicesPropertiesService).getTeeServicesProperties(CHAIN_TASK_ID); + verify(lasServicesManager).startLasService(CHAIN_TASK_ID); } // endregion