diff --git a/gradle.properties b/gradle.properties index 5a4a7cd65..1dd0a7903 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,8 +4,8 @@ version=9.2.0 iexecCommonsPocoVersion=5.4.0 iexecCommonVersion=9.2.0 iexecCommonsContainersVersion=2.0.0 -iexecResultVersion=9.0.0 -iexecSmsVersion=9.0.0 -iexecCoreVersion=9.0.0 +iexecResultVersion=9.1.0 +iexecSmsVersion=9.3.0 +iexecCoreVersion=9.2.1 nexusUser nexusPassword diff --git a/src/main/java/com/iexec/worker/chain/ContributionService.java b/src/main/java/com/iexec/worker/chain/ContributionService.java index e10b21d9a..50d583014 100644 --- a/src/main/java/com/iexec/worker/chain/ContributionService.java +++ b/src/main/java/com/iexec/worker/chain/ContributionService.java @@ -169,7 +169,8 @@ public Contribution getContribution(ComputedFile computedFile) { String enclaveChallenge = workerpoolAuthorization.getEnclaveChallenge(); String enclaveSignature = computedFile.getEnclaveSignature(); - if (iexecHubService.getTaskDescription(chainTaskId).requiresSgx()) { + final TaskDescription taskDescription = iexecHubService.getTaskDescription(chainTaskId); + if (taskDescription.requiresSgx() || taskDescription.requiresTdx()) { if (!enclaveAuthorizationService.isVerifiedEnclaveSignature( chainTaskId, resultHash, resultSeal, enclaveSignature, enclaveChallenge)) { log.error("Cannot get contribution with invalid enclave " + diff --git a/src/main/java/com/iexec/worker/compute/ComputeManagerService.java b/src/main/java/com/iexec/worker/compute/ComputeManagerService.java index e8c5338db..8434e60f8 100644 --- a/src/main/java/com/iexec/worker/compute/ComputeManagerService.java +++ b/src/main/java/com/iexec/worker/compute/ComputeManagerService.java @@ -160,10 +160,10 @@ public boolean isAppDownloaded(String imageUri) { * @see PreComputeService#runTeePreCompute(TaskDescription) */ public PreComputeResponse runPreCompute(final TaskDescription taskDescription) { - log.info("Running pre-compute [chainTaskId:{}, requiresSgx:{}]", - taskDescription.getChainTaskId(), taskDescription.requiresSgx()); + log.info("Running pre-compute [chainTaskId:{}, requiresSgx:{}, requiresTdx:{}]", + taskDescription.getChainTaskId(), taskDescription.requiresSgx(), taskDescription.requiresTdx()); - if (taskDescription.requiresSgx()) { + if (taskDescription.requiresSgx() || taskDescription.requiresTdx()) { return preComputeService.runTeePreCompute(taskDescription); } return PreComputeResponse.builder().build(); @@ -178,8 +178,8 @@ public PreComputeResponse runPreCompute(final TaskDescription taskDescription) { */ public AppComputeResponse runCompute(final TaskDescription taskDescription) { final String chainTaskId = taskDescription.getChainTaskId(); - log.info("Running compute [chainTaskId:{}, requiresSgx:{}]", - chainTaskId, taskDescription.requiresSgx()); + log.info("Running compute [chainTaskId:{}, requiresSgx:{}, requiresTdx:{}]", + chainTaskId, taskDescription.requiresSgx(), taskDescription.requiresTdx()); final AppComputeResponse appComputeResponse = appComputeService.runCompute(taskDescription); @@ -211,11 +211,11 @@ private void writeLogs(String chainTaskId, String filename, String logs) { */ public PostComputeResponse runPostCompute(final TaskDescription taskDescription) { final String chainTaskId = taskDescription.getChainTaskId(); - log.info("Running post-compute [chainTaskId:{}, requiresSgx:{}]", - chainTaskId, taskDescription.requiresSgx()); + log.info("Running post-compute [chainTaskId:{}, requiresSgx:{}, requiresTdx:{}]", + chainTaskId, taskDescription.requiresSgx(), taskDescription.requiresTdx()); final PostComputeResponse postComputeResponse; - if (!taskDescription.requiresSgx()) { + if (!taskDescription.requiresSgx() && !taskDescription.requiresTdx()) { postComputeResponse = postComputeService.runStandardPostCompute(taskDescription); } else { postComputeResponse = postComputeService.runTeePostCompute(taskDescription); diff --git a/src/main/java/com/iexec/worker/compute/app/AppComputeService.java b/src/main/java/com/iexec/worker/compute/app/AppComputeService.java index 2ecc05051..018a26d28 100644 --- a/src/main/java/com/iexec/worker/compute/app/AppComputeService.java +++ b/src/main/java/com/iexec/worker/compute/app/AppComputeService.java @@ -63,7 +63,7 @@ public AppComputeResponse runCompute(final TaskDescription taskDescription) { final List env; final HostConfig hostConfig; - if (taskDescription.requiresSgx()) { + if (taskDescription.requiresSgx() || taskDescription.requiresTdx()) { final TeeService teeService = teeServicesManager.getTeeService(taskDescription.getTeeFramework()); env = teeService.buildComputeDockerEnv(taskDescription); binds.addAll(teeService.getAdditionalBindings().stream().map(Bind::parse).toList()); diff --git a/src/main/java/com/iexec/worker/compute/post/PostComputeService.java b/src/main/java/com/iexec/worker/compute/post/PostComputeService.java index 159178e30..3e872b05e 100644 --- a/src/main/java/com/iexec/worker/compute/post/PostComputeService.java +++ b/src/main/java/com/iexec/worker/compute/post/PostComputeService.java @@ -38,6 +38,7 @@ import com.iexec.worker.tee.TeeServicesPropertiesService; import com.iexec.worker.workflow.WorkflowError; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; import java.io.IOException; @@ -181,6 +182,10 @@ public PostComputeResponse runTeePostCompute(final TaskDescription taskDescripti .withBinds(binds) .withDevices(teeService.getDevices()) .withNetworkMode(workerConfigService.getDockerNetworkName()); + // TDX specific config to access worker DNS from post-compute + if (taskDescription.requiresTdx() && !StringUtils.isBlank(workerConfigService.getDockerExtraHosts())) { + hostConfig.withExtraHosts(workerConfigService.getDockerExtraHosts()); + } final DockerRunRequest request = DockerRunRequest.builder() .hostConfig(hostConfig) .chainTaskId(chainTaskId) diff --git a/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java b/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java index 898958862..17864d0de 100644 --- a/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java +++ b/src/main/java/com/iexec/worker/compute/pre/PreComputeService.java @@ -35,6 +35,7 @@ import com.iexec.worker.tee.TeeServicesPropertiesService; import com.iexec.worker.workflow.WorkflowError; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import org.springframework.stereotype.Service; import java.time.Duration; @@ -147,6 +148,10 @@ private Integer prepareTeeInputData(final TaskDescription taskDescription) throw .withBinds(binds) .withDevices(teeService.getDevices()) .withNetworkMode(workerConfigService.getDockerNetworkName()); + // TDX specific config to access worker DNS from pre-compute + if (taskDescription.requiresTdx() && !StringUtils.isBlank(workerConfigService.getDockerExtraHosts())) { + hostConfig.withExtraHosts(workerConfigService.getDockerExtraHosts()); + } final DockerRunRequest request = DockerRunRequest.builder() .hostConfig(hostConfig) .chainTaskId(chainTaskId) diff --git a/src/main/java/com/iexec/worker/config/WorkerConfigurationService.java b/src/main/java/com/iexec/worker/config/WorkerConfigurationService.java index 52f0db493..97104903b 100644 --- a/src/main/java/com/iexec/worker/config/WorkerConfigurationService.java +++ b/src/main/java/com/iexec/worker/config/WorkerConfigurationService.java @@ -41,6 +41,7 @@ public class WorkerConfigurationService { private Integer overrideAvailableCpuCount; @Value("${worker.gpu-enabled}") + @Getter private boolean isGpuEnabled; @Value("${worker.gas-price-multiplier}") @@ -67,6 +68,10 @@ public class WorkerConfigurationService { @Getter private String dockerNetworkName; + @Value("${worker.docker-extra-hosts:}") + @Getter + private String dockerExtraHosts; + @PostConstruct private void postConstruct() { if (overrideAvailableCpuCount != null && overrideAvailableCpuCount <= 0) { @@ -75,10 +80,6 @@ private void postConstruct() { } } - public boolean isGpuEnabled() { - return isGpuEnabled; - } - public String getWorkerBaseDir() { return workerBaseDir + File.separator + workerName; } diff --git a/src/main/java/com/iexec/worker/result/ResultService.java b/src/main/java/com/iexec/worker/result/ResultService.java index 4c7ae1d02..449a5bb98 100644 --- a/src/main/java/com/iexec/worker/result/ResultService.java +++ b/src/main/java/com/iexec/worker/result/ResultService.java @@ -191,7 +191,7 @@ public String uploadResultAndGetLink(final WorkerpoolAuthorization workerpoolAut } // Cloud computing - tee - if (task.requiresSgx()) { + if (task.requiresSgx() || task.requiresTdx()) { log.info("Web2 storage, already uploaded (with tee) [chainTaskId:{}]", chainTaskId); return getWeb2ResultLink(task); } diff --git a/src/main/java/com/iexec/worker/task/TaskManagerService.java b/src/main/java/com/iexec/worker/task/TaskManagerService.java index 3ef974ac8..f21391c2f 100644 --- a/src/main/java/com/iexec/worker/task/TaskManagerService.java +++ b/src/main/java/com/iexec/worker/task/TaskManagerService.java @@ -107,12 +107,12 @@ ReplicateActionResponse start(final TaskDescription taskDescription) { } // result encryption is not supported for standard tasks - if (!taskDescription.requiresSgx() && taskDescription.getDealParams().isIexecResultEncryption()) { + if (!taskDescription.requiresSgx() && !taskDescription.requiresTdx() && taskDescription.getDealParams().isIexecResultEncryption()) { return getFailureResponseAndPrintErrors( List.of(new WorkflowError(TASK_DESCRIPTION_INVALID)), context, chainTaskId); } - if (taskDescription.requiresSgx()) { + if (taskDescription.requiresSgx() || taskDescription.requiresTdx()) { // If any TEE prerequisite is not met, // then we won't be able to run the task. // So it should be aborted right now. @@ -195,7 +195,7 @@ ReplicateActionResponse downloadData(final TaskDescription taskDescription) { requireNonNull(taskDescription, "task description must not be null"); final String chainTaskId = taskDescription.getChainTaskId(); // Return early if TEE task - if (taskDescription.requiresSgx()) { + if (taskDescription.requiresSgx() || taskDescription.requiresTdx()) { log.info("Dataset and input files will be downloaded by the pre-compute enclave [chainTaskId:{}]", chainTaskId); return ReplicateActionResponse.success(); } @@ -256,7 +256,7 @@ ReplicateActionResponse compute(final TaskDescription taskDescription) { List.of(new WorkflowError(APP_NOT_FOUND_LOCALLY)), context, chainTaskId); } - if (taskDescription.requiresSgx()) { + if (taskDescription.requiresSgx() || taskDescription.requiresTdx()) { final TeeService teeService = teeServicesManager.getTeeService(taskDescription.getTeeFramework()); if (!teeService.prepareTeeForTask(chainTaskId)) { return getFailureResponseAndPrintErrors( diff --git a/src/main/java/com/iexec/worker/tee/TeeService.java b/src/main/java/com/iexec/worker/tee/TeeService.java index faf0ce79e..f6ebda786 100644 --- a/src/main/java/com/iexec/worker/tee/TeeService.java +++ b/src/main/java/com/iexec/worker/tee/TeeService.java @@ -56,7 +56,7 @@ public List areTeePrerequisitesMetForTask(final String chainTaskI // If it can't be loaded, then we won't be able to run the task. smsService.getSmsClient(chainTaskId); } catch (SmsClientCreationException e) { - log.error("Couldn't get SmsClient [chainTaskId: {}]", chainTaskId, e); + log.error("Couldn't get SmsClient [chainTaskId:{}]", chainTaskId, e); return List.of(new WorkflowError(ReplicateStatusCause.UNKNOWN_SMS)); } diff --git a/src/main/java/com/iexec/worker/tee/TeeServicesManager.java b/src/main/java/com/iexec/worker/tee/TeeServicesManager.java index dbdfa4ba0..11df174c7 100644 --- a/src/main/java/com/iexec/worker/tee/TeeServicesManager.java +++ b/src/main/java/com/iexec/worker/tee/TeeServicesManager.java @@ -19,16 +19,20 @@ import com.iexec.commons.poco.tee.TeeFramework; import com.iexec.worker.tee.gramine.TeeGramineService; import com.iexec.worker.tee.scone.TeeSconeService; +import com.iexec.worker.tee.tdx.TeeTdxService; import org.springframework.stereotype.Service; @Service public class TeeServicesManager { + private final TeeTdxService teeTdxService; private final TeeSconeService teeSconeService; private final TeeGramineService teeGramineService; - public TeeServicesManager(final TeeSconeService teeSconeService, + public TeeServicesManager(final TeeTdxService teeTdxService, + final TeeSconeService teeSconeService, final TeeGramineService teeGramineService) { + this.teeTdxService = teeTdxService; this.teeSconeService = teeSconeService; this.teeGramineService = teeGramineService; } @@ -39,9 +43,9 @@ public TeeService getTeeService(final TeeFramework teeFramework) { } return switch (teeFramework) { + case TDX -> teeTdxService; case SCONE -> teeSconeService; case GRAMINE -> teeGramineService; - default -> throw new IllegalArgumentException("No TEE service defined for this TEE framework."); }; } } diff --git a/src/main/java/com/iexec/worker/tee/TeeServicesPropertiesService.java b/src/main/java/com/iexec/worker/tee/TeeServicesPropertiesService.java index 961bdd99a..f2de74c75 100644 --- a/src/main/java/com/iexec/worker/tee/TeeServicesPropertiesService.java +++ b/src/main/java/com/iexec/worker/tee/TeeServicesPropertiesService.java @@ -67,8 +67,26 @@ public TeeServicesProperties getTeeServicesProperties(final String chainTaskId) return propertiesForTask.get(chainTaskId); } + public List putTeeServicesPropertiesForTask(final String chainTaskId, final TeeServicesProperties properties) { + final List errors = new ArrayList<>(); + + final String preComputeImage = properties.getPreComputeProperties().getImage(); + final String postComputeImage = properties.getPostComputeProperties().getImage(); + errors.addAll(checkImageIsPresentOrDownload(preComputeImage, chainTaskId, "preComputeImage")); + errors.addAll(checkImageIsPresentOrDownload(postComputeImage, chainTaskId, "postComputeImage")); + + propertiesForTask.put(chainTaskId, properties); + log.info("TEE services properties storage in cache [chainTaskId:{}, contains-key:{}]", + chainTaskId, propertiesForTask.containsKey(chainTaskId)); + + return List.copyOf(errors); + } + public List retrieveTeeServicesProperties(final String chainTaskId) { final TaskDescription taskDescription = iexecHubService.getTaskDescription(chainTaskId); + if (taskDescription.requiresTdx()) { + return List.of(); + } // TODO errors could be renamed for APP enclave checks final TeeEnclaveConfiguration teeEnclaveConfiguration = taskDescription.getAppEnclaveConfiguration(); @@ -109,19 +127,7 @@ public List retrieveTeeServicesProperties(final String chainTaskI } log.info("TEE services properties received [chainTaskId:{}]", chainTaskId); - final String preComputeImage = properties.getPreComputeProperties().getImage(); - final String postComputeImage = properties.getPostComputeProperties().getImage(); - final List errors = new ArrayList<>(); - - errors.addAll(checkImageIsPresentOrDownload(preComputeImage, chainTaskId, "preComputeImage")); - errors.addAll(checkImageIsPresentOrDownload(postComputeImage, chainTaskId, "postComputeImage")); - - if (errors.isEmpty()) { - propertiesForTask.put(chainTaskId, properties); - log.info("TEE services properties storage in cache [chainTaskId:{}, contains-key:{}]", - chainTaskId, propertiesForTask.containsKey(chainTaskId)); - } - return List.copyOf(errors); + return putTeeServicesPropertiesForTask(chainTaskId, properties); } private List checkImageIsPresentOrDownload(final String image, final String chainTaskId, final String imageType) { diff --git a/src/main/java/com/iexec/worker/tee/tdx/TdxSession.java b/src/main/java/com/iexec/worker/tee/tdx/TdxSession.java new file mode 100644 index 000000000..932e946b1 --- /dev/null +++ b/src/main/java/com/iexec/worker/tee/tdx/TdxSession.java @@ -0,0 +1,25 @@ +/* + * Copyright 2025 IEXEC BLOCKCHAIN TECH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.iexec.worker.tee.tdx; + +import java.util.List; +import java.util.Map; + +public record TdxSession(String name, String version, List services) { + public record Service(String name, String image_name, String fingerprint, Map environment) { + } +} diff --git a/src/main/java/com/iexec/worker/tee/tdx/TeeTdxService.java b/src/main/java/com/iexec/worker/tee/tdx/TeeTdxService.java new file mode 100644 index 000000000..8399dbfe7 --- /dev/null +++ b/src/main/java/com/iexec/worker/tee/tdx/TeeTdxService.java @@ -0,0 +1,178 @@ +/* + * Copyright 2025 IEXEC BLOCKCHAIN TECH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.iexec.worker.tee.tdx; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.dockerjava.api.model.Device; +import com.iexec.common.lifecycle.purge.Purgeable; +import com.iexec.commons.poco.chain.WorkerpoolAuthorization; +import com.iexec.commons.poco.task.TaskDescription; +import com.iexec.sms.api.TeeSessionGenerationError; +import com.iexec.sms.api.TeeSessionGenerationResponse; +import com.iexec.sms.api.config.TdxServicesProperties; +import com.iexec.sms.api.config.TeeAppProperties; +import com.iexec.worker.config.WorkerConfigurationService; +import com.iexec.worker.sms.SmsService; +import com.iexec.worker.sms.TeeSessionGenerationException; +import com.iexec.worker.tee.TeeService; +import com.iexec.worker.tee.TeeServicesPropertiesService; +import jakarta.annotation.PreDestroy; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Stream; + +@Slf4j +@Service +public class TeeTdxService extends TeeService implements Purgeable { + private final String secretProviderAgent; + private final WorkerConfigurationService workerConfigurationService; + + private final ObjectMapper mapper = new ObjectMapper(); + private final Map tdxSessions = new ConcurrentHashMap<>(); + + public TeeTdxService(final SmsService smsService, + final TeeServicesPropertiesService teeServicesPropertiesService, + @Value("${tee.tdx.secret-provider-agent}") final String secretProviderAgent, + final WorkerConfigurationService workerConfigurationService) { + super(smsService, teeServicesPropertiesService); + this.secretProviderAgent = secretProviderAgent; + this.workerConfigurationService = workerConfigurationService; + } + + @Override + public boolean isTeeEnabled() { + // FIXME add service to check TDX compatibility + return true; + } + + @Override + public void createTeeSession(final WorkerpoolAuthorization workerpoolAuthorization) throws TeeSessionGenerationException { + super.createTeeSession(workerpoolAuthorization); + final String chainTaskId = workerpoolAuthorization.getChainTaskId(); + final TeeSessionGenerationResponse teeSession = getTeeSession(chainTaskId); + try { + final String provisioningUrl = teeSession.getSecretProvisioningUrl(); + final String sessionId = teeSession.getSessionId(); + Files.createDirectories(Path.of(workerConfigurationService.getTaskBaseDir(chainTaskId))); + final String fileName = String.format("session-%s.json", chainTaskId); + final String filePath = String.format("%s/%s", workerConfigurationService.getTaskBaseDir(chainTaskId), fileName); + final Process process = new ProcessBuilder( + secretProviderAgent, "-e", provisioningUrl, "-i", sessionId, "-s", fileName, "-v", "nullverifier") + .directory(Path.of(workerConfigurationService.getTaskBaseDir(chainTaskId)).toFile()) + .start(); + final int status = process.waitFor(); + log.info("secret_provider_agent process ended [status:{}, provisioning:{}, file-path:{}]", + status, provisioningUrl, filePath); + try (final BufferedReader output = new BufferedReader(new InputStreamReader(process.getInputStream()))) { + String line; + while ((line = output.readLine()) != null) { + log.info("secret_provider_agent {}", line); + } + } + final File sessionFile = new File(filePath); + final TdxSession taskSession = mapper.readValue(sessionFile, TdxSession.class); + tdxSessions.put(chainTaskId, taskSession); + final TdxServicesProperties properties = new TdxServicesProperties( + "", // no meaning as the SMS is started with a single configuration for now + TeeAppProperties.builder().image(getService(chainTaskId, "pre-compute").findFirst().map(TdxSession.Service::image_name).orElse("")).build(), + TeeAppProperties.builder().image(getService(chainTaskId, "post-compute").findFirst().map(TdxSession.Service::image_name).orElse("")).build()); + teeServicesPropertiesService.putTeeServicesPropertiesForTask(chainTaskId, properties); + } catch (IOException e) { + log.warn("I/O error when creating TEE session for task [chainTaskId:{}]", chainTaskId, e); + throw new TeeSessionGenerationException(TeeSessionGenerationError.SECURE_SESSION_STORAGE_CALL_FAILED); + } catch (InterruptedException e) { + log.error("thread has been interrupted [chainTaskId:{}]", chainTaskId, e); + Thread.currentThread().interrupt(); + throw new TeeSessionGenerationException(TeeSessionGenerationError.SECURE_SESSION_STORAGE_CALL_FAILED); + } + } + + @Override + public boolean prepareTeeForTask(final String chainTaskId) { + return true; + } + + @Override + public List buildPreComputeDockerEnv(final TaskDescription taskDescription) { + return getDockerEnv(taskDescription.getChainTaskId(), "pre-compute"); + } + + @Override + public List buildComputeDockerEnv(final TaskDescription taskDescription) { + return getDockerEnv(taskDescription.getChainTaskId(), "app"); + } + + @Override + public List buildPostComputeDockerEnv(final TaskDescription taskDescription) { + return getDockerEnv(taskDescription.getChainTaskId(), "post-compute"); + } + + @Override + public Collection getAdditionalBindings() { + return List.of(); + } + + @Override + public List getDevices() { + return List.of(); + } + + private List getDockerEnv(final String chainTaskId, final String serviceName) { + return getService(chainTaskId, serviceName) + .findFirst() + .map(service -> service.environment().entrySet().stream() + .map(entry -> String.format("%s=%s", entry.getKey(), entry.getValue())) + .toList()).orElseGet(List::of); + } + + private Stream getService(final String chainTaskId, final String serviceName) { + final TdxSession session = tdxSessions.get(chainTaskId); + if (session == null) { + return Stream.empty(); + } + return session.services().stream() + .filter(service -> Objects.equals(serviceName, service.name())); + } + + @Override + public boolean purgeTask(final String chainTaskId) { + log.debug("purgeTask [chainTaskId:{}]", chainTaskId); + tdxSessions.remove(chainTaskId); + return super.purgeTask(chainTaskId) && !tdxSessions.containsKey(chainTaskId); + } + + @Override + @PreDestroy + public void purgeAllTasksData() { + log.info("Method purgeAllTasksData() called to perform task data cleanup."); + tdxSessions.clear(); + super.purgeAllTasksData(); + } +} diff --git a/src/main/java/com/iexec/worker/worker/WorkerService.java b/src/main/java/com/iexec/worker/worker/WorkerService.java index b091f5799..d1bbbe3a7 100644 --- a/src/main/java/com/iexec/worker/worker/WorkerService.java +++ b/src/main/java/com/iexec/worker/worker/WorkerService.java @@ -89,15 +89,17 @@ public boolean registerWorker() { log.info("Running with proxy [proxyHost:{}, proxyPort:{}]", workerConfigService.getHttpProxyHost(), workerConfigService.getHttpProxyPort()); } - WorkerModel model = WorkerModel.builder() + // FIXME add service to check TDX compatibility, use SgxService instead of TeeSconeService + final WorkerModel model = WorkerModel.builder() .name(workerConfigService.getWorkerName()) .walletAddress(workerWalletAddress) .os(workerConfigService.getOS()) .cpu(workerConfigService.getCPU()) .cpuNb(workerConfigService.getCpuCount()) .memorySize(workerConfigService.getMemorySize()) - .teeEnabled(teeSconeService.isTeeEnabled()) .gpuEnabled(workerConfigService.isGpuEnabled()) + .teeEnabled(teeSconeService.isTeeEnabled()) + .tdxEnabled(true) .build(); customCoreFeignClient.registerWorker(model); diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 24727b6a1..f7ada9680 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -51,6 +51,8 @@ management: tee: sgx: driver-mode: ${IEXEC_WORKER_SGX_DRIVER_MODE:NONE} + tdx: + secret-provider-agent: /usr/local/bin/secret_provider_agent metrics: window-size: ${IEXEC_WORKER_METRICS_WINDOW_SIZE:1000} diff --git a/src/test/java/com/iexec/worker/chain/ContributionServiceTests.java b/src/test/java/com/iexec/worker/chain/ContributionServiceTests.java index 4cdf8ea27..e58e5edd9 100644 --- a/src/test/java/com/iexec/worker/chain/ContributionServiceTests.java +++ b/src/test/java/com/iexec/worker/chain/ContributionServiceTests.java @@ -29,6 +29,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.web3j.crypto.Credentials; @@ -452,8 +454,9 @@ void getContribution() { ); } - @Test - void getContributionWithTee() { + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void getContributionWithTee(final OrderTag orderTag) { final String chainTaskId = "0x0000000000000000000000000000000000000000000000000000000000000002"; final String resultDigest = "0x0000000000000000000000000000000000000000000000000000000000000001"; @@ -465,7 +468,7 @@ void getContributionWithTee() { when(enclaveAuthorizationService. isVerifiedEnclaveSignature(anyString(), anyString(), anyString(), anyString(), anyString())) .thenReturn(true); - when(iexecHubService.getTaskDescription(chainTaskId)).thenReturn(getTaskDescription(OrderTag.TEE_SCONE)); + when(iexecHubService.getTaskDescription(chainTaskId)).thenReturn(getTaskDescription(orderTag)); final ComputedFile computedFile = ComputedFile.builder() .taskId(chainTaskId) diff --git a/src/test/java/com/iexec/worker/compute/ComputeManagerServiceTests.java b/src/test/java/com/iexec/worker/compute/ComputeManagerServiceTests.java index 51368af71..34ecd8be9 100644 --- a/src/test/java/com/iexec/worker/compute/ComputeManagerServiceTests.java +++ b/src/test/java/com/iexec/worker/compute/ComputeManagerServiceTests.java @@ -174,10 +174,11 @@ void shouldRunStandardPreCompute() { assertThat(preComputeResponse.isSuccessful()).isTrue(); } - @Test - void shouldRunTeePreCompute() { + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldRunTeePreCompute(final OrderTag orderTag) { final PreComputeResponse mockResponse = mock(PreComputeResponse.class); - final TaskDescription taskDescription = createTaskDescriptionBuilder(OrderTag.TEE_SCONE).build(); + final TaskDescription taskDescription = createTaskDescriptionBuilder(orderTag).build(); when(preComputeService.runTeePreCompute(taskDescription)).thenReturn(mockResponse); PreComputeResponse preComputeResponse = @@ -187,9 +188,10 @@ void shouldRunTeePreCompute() { .runTeePreCompute(taskDescription); } - @Test - void shouldRunTeePreComputeWithFailureResponse() { - final TaskDescription taskDescription = createTaskDescriptionBuilder(OrderTag.TEE_SCONE).build(); + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldRunTeePreComputeWithFailureResponse(final OrderTag orderTag) { + final TaskDescription taskDescription = createTaskDescriptionBuilder(orderTag).build(); when(preComputeService.runTeePreCompute(taskDescription)).thenReturn(PreComputeResponse.builder() .exitCauses(List.of(new WorkflowError(ReplicateStatusCause.PRE_COMPUTE_DATASET_URL_MISSING))) .build()); @@ -352,9 +354,10 @@ void shouldRunStandardPostComputeWithFailureResponse(ReplicateStatusCause status .containsExactly(new WorkflowError(statusCause)); } - @Test - void shouldRunTeePostCompute() { - final TaskDescription taskDescription = createTaskDescriptionBuilder(OrderTag.TEE_SCONE).build(); + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldRunTeePostCompute(final OrderTag orderTag) { + final TaskDescription taskDescription = createTaskDescriptionBuilder(orderTag).build(); PostComputeResponse expectedDockerRunResponse = PostComputeResponse.builder() .stdout(dockerLogs.getStdout()) @@ -377,9 +380,10 @@ void shouldRunTeePostCompute() { verify(resultService).saveResultInfo(any(), any()); } - @Test - void shouldRunTeePostComputeWithFailureResponse() { - final TaskDescription taskDescription = createTaskDescriptionBuilder(OrderTag.TEE_SCONE).build(); + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldRunTeePostComputeWithFailureResponse(final OrderTag orderTag) { + final TaskDescription taskDescription = createTaskDescriptionBuilder(orderTag).build(); PostComputeResponse expectedDockerRunResponse = PostComputeResponse.builder() .exitCauses(List.of(new WorkflowError(ReplicateStatusCause.APP_COMPUTE_FAILED))) diff --git a/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java b/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java index b936b272d..0a89a3b43 100644 --- a/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java +++ b/src/test/java/com/iexec/worker/compute/app/AppComputeServiceTests.java @@ -37,6 +37,8 @@ import com.iexec.worker.tee.TeeServicesManager; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; @@ -131,9 +133,10 @@ void shouldRunCompute() { ); } - @Test - void shouldRunComputeWithTeeAndConnectAppToLas() { - final TaskDescription taskDescription = getTaskDescriptionBuilder(OrderTag.TEE_SCONE) + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldRunComputeWithTeeAndConnectAppToLas(final OrderTag orderTag) { + final TaskDescription taskDescription = getTaskDescriptionBuilder(orderTag) .appEnclaveConfiguration(TeeEnclaveConfiguration.builder().heapSize(HEAP_SIZE).build()) .build(); when(teeServicesManager.getTeeService(any())).thenReturn(teeMockedService); diff --git a/src/test/java/com/iexec/worker/task/TaskManagerServiceTests.java b/src/test/java/com/iexec/worker/task/TaskManagerServiceTests.java index 8e2db19d9..f59f6d667 100644 --- a/src/test/java/com/iexec/worker/task/TaskManagerServiceTests.java +++ b/src/test/java/com/iexec/worker/task/TaskManagerServiceTests.java @@ -170,42 +170,43 @@ void shouldNotStartSinceStandardTaskWithEncryption() { .build(); when(contributionService.getCannotContributeStatusCause(CHAIN_TASK_ID)) .thenReturn(emptyCauses); - ReplicateActionResponse actionResponse = taskManagerService.start(taskDescription); + final ReplicateActionResponse actionResponse = taskManagerService.start(taskDescription); assertThat(actionResponse.isSuccess()).isFalse(); assertThat(actionResponse.getDetails().getCause()).isEqualTo(TASK_DESCRIPTION_INVALID); } - @Test - void shouldStartTeeTask() { + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldStartTeeTask(final OrderTag orderTag) { when(contributionService.getCannotContributeStatusCause(CHAIN_TASK_ID)) .thenReturn(emptyCauses); when(teeServicesManager.getTeeService(any())).thenReturn(teeMockedService); when(teeMockedService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID)).thenReturn(emptyCauses); - ReplicateActionResponse actionResponse = - taskManagerService.start(getTaskDescriptionBuilder(OrderTag.TEE_SCONE).build()); + final ReplicateActionResponse actionResponse = taskManagerService.start(getTaskDescriptionBuilder(orderTag).build()); assertThat(actionResponse.isSuccess()).isTrue(); verifyNoInteractions(iexecHubService); } - @Test - void shouldNotStartSinceTeePrerequisitesAreNotMet() { + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldNotStartSinceTeePrerequisitesAreNotMet(final OrderTag orderTag) { when(contributionService.getCannotContributeStatusCause(CHAIN_TASK_ID)) .thenReturn(emptyCauses); when(teeServicesManager.getTeeService(any())).thenReturn(teeMockedService); when(teeMockedService.areTeePrerequisitesMetForTask(CHAIN_TASK_ID)) .thenReturn(List.of(new WorkflowError(TEE_NOT_SUPPORTED))); - ReplicateActionResponse actionResponse = - taskManagerService.start(getTaskDescriptionBuilder(OrderTag.TEE_SCONE).build()); + final ReplicateActionResponse actionResponse = taskManagerService.start(getTaskDescriptionBuilder(orderTag).build()); assertThat(actionResponse.isSuccess()).isFalse(); assertThat(actionResponse.getDetails().getCause()).isEqualTo(TEE_NOT_SUPPORTED); } - @Test - void shouldNotStartSinceTeeSessionCreationFailed() throws TeeSessionGenerationException { + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldNotStartSinceTeeSessionCreationFailed(final OrderTag orderTag) throws TeeSessionGenerationException { when(contributionService.getCannotContributeStatusCause(CHAIN_TASK_ID)) .thenReturn(emptyCauses); when(teeServicesManager.getTeeService(any())).thenReturn(teeMockedService); @@ -214,8 +215,7 @@ void shouldNotStartSinceTeeSessionCreationFailed() throws TeeSessionGenerationEx doThrow(new TeeSessionGenerationException(TeeSessionGenerationError.UNKNOWN_ISSUE)) .when(teeMockedService).createTeeSession(any()); - ReplicateActionResponse actionResponse = - taskManagerService.start(getTaskDescriptionBuilder(OrderTag.TEE_SCONE).build()); + final ReplicateActionResponse actionResponse = taskManagerService.start(getTaskDescriptionBuilder(orderTag).build()); assertThat(actionResponse.isSuccess()).isFalse(); assertThat(actionResponse.getDetails().getCause()).isEqualTo(TEE_SESSION_GENERATION_UNKNOWN_ISSUE); @@ -413,11 +413,11 @@ void shouldDownloadInputFilesAndNotDataset() throws Exception { // with dataset + with input files + TEE task - @Test - void shouldNotDownloadDataWithDatasetUriForTeeTaskAndReturnSuccess() { - final TaskDescription taskDescription = getTaskDescriptionBuilder(OrderTag.TEE_SCONE).build(); - final ReplicateActionResponse actionResponse = - taskManagerService.downloadData(taskDescription); + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldNotDownloadDataWithDatasetUriForTeeTaskAndReturnSuccess(final OrderTag orderTag) { + final TaskDescription taskDescription = getTaskDescriptionBuilder(orderTag).build(); + final ReplicateActionResponse actionResponse = taskManagerService.downloadData(taskDescription); assertThat(actionResponse.isSuccess()).isTrue(); verifyNoInteractions(dataService); } @@ -652,9 +652,10 @@ void shouldComputeStandardTask() { verifyNoInteractions(teeServicesManager, resultService); } - @Test - void shouldComputeTeeTask() { - final TaskDescription taskDescription = getTaskDescriptionBuilder(OrderTag.TEE_SCONE).build(); + @ParameterizedTest + @EnumSource(value = OrderTag.class, names = {"TEE_SCONE", "TEE_TDX"}) + void shouldComputeTeeTask(final OrderTag orderTag) { + final TaskDescription taskDescription = getTaskDescriptionBuilder(orderTag).build(); when(contributionService.getCannotContributeStatusCause(CHAIN_TASK_ID)) .thenReturn(emptyCauses); diff --git a/src/test/java/com/iexec/worker/tee/TeeServicesManagerTests.java b/src/test/java/com/iexec/worker/tee/TeeServicesManagerTests.java index dae18b936..7cfdf85f6 100644 --- a/src/test/java/com/iexec/worker/tee/TeeServicesManagerTests.java +++ b/src/test/java/com/iexec/worker/tee/TeeServicesManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2022-2023 IEXEC BLOCKCHAIN TECH + * Copyright 2022-2025 IEXEC BLOCKCHAIN TECH * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,21 +19,26 @@ import com.iexec.commons.poco.tee.TeeFramework; import com.iexec.worker.tee.gramine.TeeGramineService; import com.iexec.worker.tee.scone.TeeSconeService; -import org.junit.jupiter.api.BeforeEach; +import com.iexec.worker.tee.tdx.TeeTdxService; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.InjectMocks; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.jupiter.MockitoExtension; import java.util.stream.Stream; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +@ExtendWith(MockitoExtension.class) class TeeServicesManagerTests { + @Mock + TeeTdxService teeTdxService; @Mock TeeSconeService teeSconeService; @Mock @@ -42,13 +47,9 @@ class TeeServicesManagerTests { @InjectMocks TeeServicesManager teeServicesManager; - @BeforeEach - void init() { - MockitoAnnotations.openMocks(this); - } - static Stream teeServices() { return Stream.of( + Arguments.of(TeeFramework.TDX, TeeTdxService.class), Arguments.of(TeeFramework.SCONE, TeeSconeService.class), Arguments.of(TeeFramework.GRAMINE, TeeGramineService.class) ); diff --git a/src/test/java/com/iexec/worker/tee/TeeServicesPropertiesServiceTests.java b/src/test/java/com/iexec/worker/tee/TeeServicesPropertiesServiceTests.java index 3a1290537..5ab0dcecf 100644 --- a/src/test/java/com/iexec/worker/tee/TeeServicesPropertiesServiceTests.java +++ b/src/test/java/com/iexec/worker/tee/TeeServicesPropertiesServiceTests.java @@ -129,6 +129,14 @@ void shouldRetrieveTeeServicesConfiguration() { verify(dockerClient, never()).pullImage(POST_COMPUTE_IMAGE); } + @Test + void shouldNotRetrieveTeeServicesConfigurationForTdx() { + final TaskDescription taskDescription = taskDescriptionBuilder.teeFramework(TeeFramework.TDX).build(); + when(iexecHubService.getTaskDescription(CHAIN_TASK_ID)).thenReturn(taskDescription); + assertThat(teeServicesPropertiesService.retrieveTeeServicesProperties(CHAIN_TASK_ID)).isEmpty(); + verifyNoInteractions(smsService, smsClient, dockerService, dockerClient); + } + @Test void shouldNotRetrieveTeeServicesConfigurationWhenTeeEnclaveConfigurationIsNull() { final TaskDescription taskDescription = taskDescriptionBuilder.appEnclaveConfiguration(null).build(); diff --git a/src/test/java/com/iexec/worker/tee/tdx/TeeTdxServiceTests.java b/src/test/java/com/iexec/worker/tee/tdx/TeeTdxServiceTests.java new file mode 100644 index 000000000..f47686b0d --- /dev/null +++ b/src/test/java/com/iexec/worker/tee/tdx/TeeTdxServiceTests.java @@ -0,0 +1,240 @@ +/* + * Copyright 2025 IEXEC BLOCKCHAIN TECH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.iexec.worker.tee.tdx; + +import com.iexec.commons.poco.chain.WorkerpoolAuthorization; +import com.iexec.commons.poco.task.TaskDescription; +import com.iexec.sms.api.TeeSessionGenerationError; +import com.iexec.sms.api.TeeSessionGenerationResponse; +import com.iexec.worker.config.WorkerConfigurationService; +import com.iexec.worker.sms.SmsService; +import com.iexec.worker.sms.TeeSessionGenerationException; +import com.iexec.worker.tee.TeeServicesPropertiesService; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.test.util.ReflectionTestUtils; + +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class TeeTdxServiceTests { + private static final String CHAIN_TASK_ID = "0xe7610523051210870457895cd72959aeef5872a8d328db7557ff3489834c3ce6"; + private static final String TDX_SESSION_ID = "01234567890000" + CHAIN_TASK_ID; + private static final String TDX_SESSION_VERSION = "0.1.0"; + private static final String DOCKER_ENV_VAR_NAME = "ENV"; + private static final String DOCKER_ENV_VAR_VALUE = "VALUE"; + + @Mock + private SmsService smsService; + @Mock + private TeeServicesPropertiesService teeServicesPropertiesService; + @Mock + private WorkerConfigurationService workerConfigurationService; + @InjectMocks + private TeeTdxService teeTdxService; + + @TempDir + Path taskBaseDir; + + final TaskDescription taskDescription = TaskDescription.builder().chainTaskId(CHAIN_TASK_ID).build(); + + @Test + void shouldBeTeeEnabled() { + assertThat(teeTdxService.isTeeEnabled()).isTrue(); + } + + // region createTeeSession + @Test + void shouldCreateTeeSession() throws TeeSessionGenerationException { + ReflectionTestUtils.setField(teeTdxService, "secretProviderAgent", this.getClass().getClassLoader().getResource("secret_provider_agent").getFile()); + final WorkerpoolAuthorization authorization = WorkerpoolAuthorization.builder().chainTaskId(CHAIN_TASK_ID).build(); + when(smsService.createTeeSession(authorization)) + .thenReturn(new TeeSessionGenerationResponse("sessionId", "secretProvisioningUrl")); + when(workerConfigurationService.getTaskBaseDir(CHAIN_TASK_ID)).thenReturn(taskBaseDir.toString()); + assertThatNoException().isThrownBy(() -> teeTdxService.createTeeSession(authorization)); + } + + @Test + void shouldFailToCreateTeeSession() throws TeeSessionGenerationException { + ReflectionTestUtils.setField(teeTdxService, "secretProviderAgent", "/tmp/not-found/secret_provider_agent"); + final WorkerpoolAuthorization authorization = WorkerpoolAuthorization.builder().chainTaskId(CHAIN_TASK_ID).build(); + when(smsService.createTeeSession(authorization)) + .thenReturn(new TeeSessionGenerationResponse("sessionId", "secretProvisioningUrl")); + when(workerConfigurationService.getTaskBaseDir(CHAIN_TASK_ID)).thenReturn(taskBaseDir.toString()); + assertThatThrownBy(() -> teeTdxService.createTeeSession(authorization)) + .isInstanceOf(TeeSessionGenerationException.class) + .hasFieldOrPropertyWithValue("TeeSessionGenerationError", TeeSessionGenerationError.SECURE_SESSION_STORAGE_CALL_FAILED); + } + // endregion + + @Test + void shouldPrepareTeeForTask() { + assertThat(teeTdxService.prepareTeeForTask(CHAIN_TASK_ID)).isTrue(); + } + + // region buildPreComputeDockerEnv + @Test + void shouldBuildEmptyPreComputeDockerEnvWhenNoSession() { + assertThat(teeTdxService.buildPreComputeDockerEnv(taskDescription)).isEmpty(); + } + + @Test + void shouldBuildEmptyPreComputeDockerEnvWhenEmptyMap() { + final List sessionServices = List.of(new TdxSession.Service( + "pre-compute", "image_name", "fingerprint", Map.of())); + final Map sessions = Map.of(CHAIN_TASK_ID, new TdxSession(TDX_SESSION_ID, TDX_SESSION_VERSION, sessionServices)); + ReflectionTestUtils.setField(teeTdxService, "tdxSessions", sessions); + assertThat(teeTdxService.buildPreComputeDockerEnv(taskDescription)).isEmpty(); + } + + @Test + void shouldBuildPreComputeDockerEnv() { + final List sessionServices = List.of(new TdxSession.Service( + "pre-compute", "image_name", "fingerprint", Map.of(DOCKER_ENV_VAR_NAME, DOCKER_ENV_VAR_VALUE))); + final Map sessions = Map.of(CHAIN_TASK_ID, new TdxSession(TDX_SESSION_ID, TDX_SESSION_VERSION, sessionServices)); + ReflectionTestUtils.setField(teeTdxService, "tdxSessions", sessions); + assertThat(teeTdxService.buildPreComputeDockerEnv(taskDescription)) + .containsExactly(String.format("%s=%s", DOCKER_ENV_VAR_NAME, DOCKER_ENV_VAR_VALUE)); + } + // endregion + + // region buildComputeDockerEnv + @Test + void shouldBuildEmptyComputeDockerEnvWhenNoSession() { + assertThat(teeTdxService.buildComputeDockerEnv(taskDescription)).isEmpty(); + } + + @Test + void shouldBuildEmptyComputeDockerEnvWhenEmptyMap() { + final List sessionServices = List.of(new TdxSession.Service( + "app", "image_name", "fingerprint", Map.of())); + final Map sessions = Map.of(CHAIN_TASK_ID, new TdxSession(TDX_SESSION_ID, TDX_SESSION_VERSION, sessionServices)); + ReflectionTestUtils.setField(teeTdxService, "tdxSessions", sessions); + assertThat(teeTdxService.buildComputeDockerEnv(taskDescription)).isEmpty(); + } + + @Test + void shouldBuildComputeDockerEnv() { + final List sessionServices = List.of(new TdxSession.Service( + "app", "image_name", "fingerprint", Map.of(DOCKER_ENV_VAR_NAME, DOCKER_ENV_VAR_VALUE))); + final Map sessions = Map.of(CHAIN_TASK_ID, new TdxSession(TDX_SESSION_ID, TDX_SESSION_VERSION, sessionServices)); + ReflectionTestUtils.setField(teeTdxService, "tdxSessions", sessions); + assertThat(teeTdxService.buildComputeDockerEnv(taskDescription)) + .containsExactly(String.format("%s=%s", DOCKER_ENV_VAR_NAME, DOCKER_ENV_VAR_VALUE)); + } + // endregion + + // region buildPostComputeDockerEnv + @Test + void shouldBuildEmptyPostComputeDockerEnvWhenNoSession() { + assertThat(teeTdxService.buildPostComputeDockerEnv(taskDescription)).isEmpty(); + } + + @Test + void shouldBuildEmptyPostComputeDockerEnvWhenEmptyMap() { + final List sessionServices = List.of(new TdxSession.Service( + "post-compute", "image_name", "fingerprint", Map.of())); + final Map sessions = Map.of(CHAIN_TASK_ID, new TdxSession(TDX_SESSION_ID, TDX_SESSION_VERSION, sessionServices)); + ReflectionTestUtils.setField(teeTdxService, "tdxSessions", sessions); + assertThat(teeTdxService.buildPostComputeDockerEnv(taskDescription)).isEmpty(); + } + + @Test + void shouldBuildPostComputeDockerEnv() { + final List sessionServices = List.of(new TdxSession.Service( + "post-compute", "image_name", "fingerprint", Map.of(DOCKER_ENV_VAR_NAME, DOCKER_ENV_VAR_VALUE))); + final Map sessions = Map.of(CHAIN_TASK_ID, new TdxSession(TDX_SESSION_ID, TDX_SESSION_VERSION, sessionServices)); + ReflectionTestUtils.setField(teeTdxService, "tdxSessions", sessions); + assertThat(teeTdxService.buildPostComputeDockerEnv(taskDescription)) + .containsExactly(String.format("%s=%s", DOCKER_ENV_VAR_NAME, DOCKER_ENV_VAR_VALUE)); + } + // endregion + + @Test + void shouldNotRequireAdditionalBindings() { + assertThat(teeTdxService.getAdditionalBindings()).isEmpty(); + } + + @Test + void shouldNotRequireDevices() { + assertThat(teeTdxService.getDevices()).isEmpty(); + } + + // region TEE sessions cache + private void prefillTeeSessionsCache(final Map teeSessions, + final Map tdxSessions) { + teeSessions.put("taskId1", new TeeSessionGenerationResponse("sessionId1", "sessionUrl1")); + teeSessions.put("taskId2", new TeeSessionGenerationResponse("sessionId2", "sessionUrl2")); + ReflectionTestUtils.setField(teeTdxService, "teeSessions", teeSessions); + tdxSessions.put("taskId1", new TdxSession("sessionId1", TDX_SESSION_VERSION, List.of())); + tdxSessions.put("taskId2", new TdxSession("sessionId2", TDX_SESSION_VERSION, List.of())); + ReflectionTestUtils.setField(teeTdxService, "tdxSessions", tdxSessions); + } + + @Test + void shouldNotModifyCacheWhenNoSessionInCache() { + final Map teeSessions = new ConcurrentHashMap<>(); + final Map tdxSessions = new ConcurrentHashMap<>(); + prefillTeeSessionsCache(teeSessions, tdxSessions); + teeTdxService.purgeTask("taskId3"); + assertThat(teeSessions) + .usingRecursiveComparison() + .isEqualTo(Map.of( + "taskId1", new TeeSessionGenerationResponse("sessionId1", "sessionUrl1"), + "taskId2", new TeeSessionGenerationResponse("sessionId2", "sessionUrl2"))); + assertThat(tdxSessions) + .usingRecursiveComparison() + .isEqualTo(Map.of( + "taskId1", new TdxSession("sessionId1", TDX_SESSION_VERSION, List.of()), + "taskId2", new TdxSession("sessionId2", TDX_SESSION_VERSION, List.of()))); + } + + @Test + void shouldRemoveTeeSessionFromCache() { + final Map teeSessions = new ConcurrentHashMap<>(); + final Map tdxSessions = new ConcurrentHashMap<>(); + prefillTeeSessionsCache(teeSessions, tdxSessions); + teeTdxService.purgeTask("taskId1"); + assertThat(teeSessions) + .usingRecursiveComparison() + .isEqualTo(Map.of("taskId2", new TeeSessionGenerationResponse("sessionId2", "sessionUrl2"))); + assertThat(tdxSessions) + .usingRecursiveComparison() + .isEqualTo(Map.of("taskId2", new TdxSession("sessionId2", TDX_SESSION_VERSION, List.of()))); + } + + @Test + void shouldRemoveAllTeeSessionsFromCache() { + final Map teeSessions = new ConcurrentHashMap<>(); + final Map tdxSessions = new ConcurrentHashMap<>(); + prefillTeeSessionsCache(teeSessions, tdxSessions); + teeTdxService.purgeAllTasksData(); + assertThat(teeSessions).isEmpty(); + assertThat(tdxSessions).isEmpty(); + } + // end region +} diff --git a/src/test/resources/secret_provider_agent b/src/test/resources/secret_provider_agent new file mode 100755 index 000000000..6cc7e847c --- /dev/null +++ b/src/test/resources/secret_provider_agent @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +echo retrieving session +echo '{"name":"", "version":"0.1.0", "services":[]}' > session-0xe7610523051210870457895cd72959aeef5872a8d328db7557ff3489834c3ce6.json