diff --git a/src/main/java/dev/dbos/transact/context/DBOSContext.java b/src/main/java/dev/dbos/transact/context/DBOSContext.java index 1af8d510..b1d3935e 100644 --- a/src/main/java/dev/dbos/transact/context/DBOSContext.java +++ b/src/main/java/dev/dbos/transact/context/DBOSContext.java @@ -1,11 +1,19 @@ package dev.dbos.transact.context; public class DBOSContext { - private String workflowId; + private volatile String workflowId; private String user; - private String functionId; + private volatile int functionId; private String stepId; + public DBOSContext() { + + } + public DBOSContext(String workflowId, int functionId) { + this.workflowId = workflowId; + this.functionId = functionId ; + } + public String getWorkflowId() { return workflowId; } @@ -22,12 +30,12 @@ public void setUser(String user) { this.user = user; } - public String getFunctionId() { + public int getFunctionId() { return functionId; } - public void setFunctionId(String functionId) { - this.functionId = functionId; + public int getAndIncrementFunctionId() { + return functionId++; } public String getStepId() { @@ -37,5 +45,9 @@ public String getStepId() { public void setStepId(String stepId) { this.stepId = stepId; } + + public DBOSContext copy() { + return new DBOSContext(workflowId, functionId); + } } diff --git a/src/main/java/dev/dbos/transact/context/DBOSContextHolder.java b/src/main/java/dev/dbos/transact/context/DBOSContextHolder.java index 4626b9ae..aec81af4 100644 --- a/src/main/java/dev/dbos/transact/context/DBOSContextHolder.java +++ b/src/main/java/dev/dbos/transact/context/DBOSContextHolder.java @@ -1,7 +1,11 @@ package dev.dbos.transact.context; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + public class DBOSContextHolder { private static final ThreadLocal contextHolder = ThreadLocal.withInitial(DBOSContext::new); + private static Logger logger = LoggerFactory.getLogger(DBOSContextHolder.class); public static DBOSContext get() { return contextHolder.get(); @@ -9,10 +13,12 @@ public static DBOSContext get() { public static void clear() { contextHolder.remove(); + logger.debug("context cleared for thread " + Thread.currentThread().getId()); } public static void set(DBOSContext context) { contextHolder.set(context); + logger.debug("context set for thread " + Thread.currentThread().getId()); } } diff --git a/src/main/java/dev/dbos/transact/context/SetWorkflowID.java b/src/main/java/dev/dbos/transact/context/SetWorkflowID.java index 5faf3592..70755e5a 100644 --- a/src/main/java/dev/dbos/transact/context/SetWorkflowID.java +++ b/src/main/java/dev/dbos/transact/context/SetWorkflowID.java @@ -11,8 +11,7 @@ public SetWorkflowID(String workflowId) { @Override public void close() { - DBOSContext context = DBOSContextHolder.get(); - context.setWorkflowId(previousWorkflowId); + DBOSContextHolder.clear(); } } diff --git a/src/main/java/dev/dbos/transact/database/StepsDAO.java b/src/main/java/dev/dbos/transact/database/StepsDAO.java new file mode 100644 index 00000000..2b8992df --- /dev/null +++ b/src/main/java/dev/dbos/transact/database/StepsDAO.java @@ -0,0 +1,212 @@ +package dev.dbos.transact.database; + +import dev.dbos.transact.Constants; +import dev.dbos.transact.exceptions.UnExpectedStepException; +import dev.dbos.transact.exceptions.WorkflowCancelledException; +import dev.dbos.transact.exceptions.DBOSWorkflowConflictException; +import dev.dbos.transact.json.JSONUtil; +import dev.dbos.transact.workflow.StepInfo; +import dev.dbos.transact.workflow.WorkflowState; +import dev.dbos.transact.workflow.internal.StepResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.sql.DataSource; +import java.sql.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +public class StepsDAO { + + private Logger logger = LoggerFactory.getLogger(StepsDAO.class); + private DataSource dataSource ; + + StepsDAO(DataSource dataSource) { + this.dataSource = dataSource; + } + + public void recordStepResultTxn(StepResult result) throws SQLException + { + + String sql = String.format( + "INSERT INTO %s.operation_outputs (workflow_uuid, function_id, function_name, output, error) " + + "VALUES (?, ?, ?, ?, ?)", + Constants.DB_SCHEMA + ); + + try (Connection connection = dataSource.getConnection() ; + PreparedStatement pstmt = connection.prepareStatement(sql)) { + int paramIdx = 1; + pstmt.setString(paramIdx++, result.getWorkflowId()); + pstmt.setInt(paramIdx++, result.getFunctionId()); + pstmt.setString(paramIdx++, result.getFunctionName()); + + if (result.getOutput() != null) { + pstmt.setString(paramIdx++, result.getOutput()); + } else { + pstmt.setNull(paramIdx++, Types.LONGVARCHAR); + } + + if (result.getError() != null) { + pstmt.setString(paramIdx++, result.getError()); + } else { + pstmt.setNull(paramIdx++, Types.LONGVARCHAR); + } + + pstmt.executeUpdate(); + + } catch (SQLException e) { + if ("23505".equals(e.getSQLState())) { + throw new DBOSWorkflowConflictException(result.getWorkflowId(), + String.format("Workflow %s already exists", result.getWorkflowId())); + } else { + throw e; + } + } + } + + /** + * Checks the execution status and output of a specific operation within a workflow. + * This method corresponds to Python's '_check_operation_execution_txn'. + * + * @param workflowId The UUID of the workflow. + * @param functionId The ID of the function/operation. + * @param functionName The expected name of the function/operation. + * @param connection The active JDBC connection (corresponding to Python's 'conn: sa.Connection'). + * @return A {@link StepResult} object if the operation has completed, otherwise {@code null}. + * @throws IllegalStateException If the workflow does not exist in the status table. + * @throws WorkflowCancelledException If the workflow is in a cancelled status. + * @throws UnExpectedStepException If the recorded function name for the operation does not match the provided name. + * @throws SQLException For other database access errors. + */ + public StepResult checkStepExecutionTxn( + String workflowId, + int functionId, + String functionName, + Connection connection + ) throws SQLException, IllegalStateException, WorkflowCancelledException, UnExpectedStepException { + + String workflowStatusSql = String.format( + "SELECT status FROM %s.workflow_status WHERE workflow_uuid = ?", + Constants.DB_SCHEMA + ); + + String workflowStatus = null; + try (PreparedStatement pstmt = connection.prepareStatement(workflowStatusSql)) { + pstmt.setString(1, workflowId); + try (ResultSet rs = pstmt.executeQuery()) { + if (rs.next()) { + workflowStatus = rs.getString("status"); + } + } + } + + if (workflowStatus == null) { + throw new IllegalStateException(String.format("Error: Workflow %s does not exist", workflowId)); + } + + if (Objects.equals(workflowStatus, WorkflowState.CANCELLED.name())) { + throw new WorkflowCancelledException( + String.format("Workflow %s is cancelled. Aborting function.", workflowId) + ); + } + + String operationOutputSql = String.format( + "SELECT output, error, function_name " + + "FROM %s.operation_outputs " + + "WHERE workflow_uuid = ? AND function_id = ?", + Constants.DB_SCHEMA + ); + + StepResult recordedResult = null; + String recordedFunctionName = null; + + try (PreparedStatement pstmt = connection.prepareStatement(operationOutputSql)) { + pstmt.setString(1, workflowId); + pstmt.setInt(2, functionId); + try (ResultSet rs = pstmt.executeQuery()) { + if (rs.next()) { // Check if any operation output row exists + String output = rs.getString("output"); + String error = rs.getString("error"); + recordedFunctionName = rs.getString("function_name"); + recordedResult = new StepResult(workflowId, functionId, recordedFunctionName, output, error); + } + } + } + + if (recordedResult == null) { + return null; + } + + if (!Objects.equals(functionName, recordedFunctionName)) { + throw new UnExpectedStepException( + workflowId, + functionId, + functionName, + recordedFunctionName + ); + } + + return recordedResult; + } + + public List listWorkflowSteps(String workflowId) throws SQLException { + String sqlTemplate = "SELECT function_id, function_name, output, error, child_workflow_id " + + "FROM %s.operation_outputs " + + "WHERE workflow_uuid = ? " + + "ORDER BY function_id;"; + final String sql = String.format(sqlTemplate, Constants.DB_SCHEMA); + System.out.println(sql); + + + List steps = new ArrayList<>(); + + try (Connection connection = dataSource.getConnection(); + PreparedStatement stmt = connection.prepareStatement(sql)) { + + stmt.setString(1, workflowId); + + try (ResultSet rs = stmt.executeQuery()) { + + while (rs.next()) { + int functionId = rs.getInt("function_id"); + String functionName = rs.getString("function_name"); + String outputData = rs.getString("output"); + String errorData = rs.getString("error"); + String childWorkflowId = rs.getString("child_workflow_id"); + System.out.println(functionId); + + // Deserialize output if present + Object output = null; + if (outputData != null) { + try { + output = JSONUtil.deserialize(outputData); + } catch (Exception e) { + throw new RuntimeException("Failed to deserialize output for function " + functionId, e); + } + } + + // Deserialize error if present + Exception error = null; + if (errorData != null) { + try { + // TODO error = JSONUtil.deserialize(errorData); + error = new Exception(errorData) ; + } catch (Exception e) { + throw new RuntimeException("Failed to deserialize error for function " + functionId, e); + } + } + + steps.add(new StepInfo(functionId, functionName, output, error, childWorkflowId)); + } + } + } catch (SQLException e) { + throw new SQLException("Failed to retrieve workflow steps for workflow: " + workflowId, e); + } + + return steps; + } +} + + diff --git a/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/src/main/java/dev/dbos/transact/database/SystemDatabase.java index 52d12156..eb5a68fa 100644 --- a/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -8,9 +8,11 @@ import dev.dbos.transact.exceptions.*; import dev.dbos.transact.json.JSONUtil; import dev.dbos.transact.workflow.ListWorkflowsInput; +import dev.dbos.transact.workflow.StepInfo; import dev.dbos.transact.workflow.WorkflowState; import dev.dbos.transact.workflow.WorkflowStatus; import dev.dbos.transact.workflow.internal.InsertWorkflowResult; +import dev.dbos.transact.workflow.internal.StepResult; import dev.dbos.transact.workflow.internal.WorkflowStatusInternal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +30,7 @@ public class SystemDatabase { private DBOSConfig config ; private static SystemDatabase instance ; private DataSource dataSource ; + private StepsDAO stepsDAO ; private SystemDatabase(DBOSConfig cfg) { config = cfg ; @@ -40,6 +43,7 @@ private SystemDatabase(DBOSConfig cfg) { } createDataSource(dbName); + stepsDAO = new StepsDAO(dataSource) ; } public static synchronized void initialize(DBOSConfig cfg) { @@ -424,7 +428,7 @@ public void updateWorkflowStatus( } } -public WorkflowStatus getWorkflow(String workflowId) { + public WorkflowStatus getWorkflow(String workflowId) { try { ListWorkflowsInput input = new ListWorkflowsInput(); @@ -438,214 +442,268 @@ public WorkflowStatus getWorkflow(String workflowId) { } throw new NonExistentWorkflowException(workflowId) ; -} - -public List listWorkflows(ListWorkflowsInput input) throws SQLException { + } - List workflows = new ArrayList<>(); + public List listWorkflows(ListWorkflowsInput input) throws SQLException { - StringBuilder sqlBuilder = new StringBuilder(); - List parameters = new ArrayList<>(); + List workflows = new ArrayList<>(); - // Start building the SELECT clause. The order of columns here is critical - // for mapping to the WorkflowStatus fields by index later in the ResultSet. - sqlBuilder.append("SELECT workflow_uuid, status, name, recovery_attempts, " + - "config_name, class_name, authenticated_user, authenticated_roles, " + - "assumed_role, queue_name, executor_id, created_at, updated_at, " + - "application_version, application_id, inputs, output, error, " + - "workflow_deadline_epoch_ms, workflow_timeout_ms "); + StringBuilder sqlBuilder = new StringBuilder(); + List parameters = new ArrayList<>(); - sqlBuilder.append(String.format("FROM %s.workflow_status ", Constants.DB_SCHEMA)); + // Start building the SELECT clause. The order of columns here is critical + // for mapping to the WorkflowStatus fields by index later in the ResultSet. + sqlBuilder.append("SELECT workflow_uuid, status, name, recovery_attempts, " + + "config_name, class_name, authenticated_user, authenticated_roles, " + + "assumed_role, queue_name, executor_id, created_at, updated_at, " + + "application_version, application_id, inputs, output, error, " + + "workflow_deadline_epoch_ms, workflow_timeout_ms "); - // --- WHERE Clauses --- - StringJoiner whereConditions = new StringJoiner(" AND "); + sqlBuilder.append(String.format("FROM %s.workflow_status ", Constants.DB_SCHEMA)); - if (input.getWorkflowName() != null) { - whereConditions.add("name = ?"); - parameters.add(input.getWorkflowName()); - } - if (input.getAuthenticatedUser() != null) { - whereConditions.add("authenticated_user = ?"); - parameters.add(input.getAuthenticatedUser()); - } - if (input.getStartTime() != null) { - whereConditions.add("created_at >= ?"); - // Convert OffsetDateTime to epoch milliseconds for comparison with DB column - parameters.add(input.getStartTime().toInstant().toEpochMilli()); - } - if (input.getEndTime() != null) { - whereConditions.add("created_at <= ?"); - // Convert OffsetDateTime to epoch milliseconds for comparison with DB column - parameters.add(input.getEndTime().toInstant().toEpochMilli()); - } - if (input.getStatus() != null) { - whereConditions.add("status = ?"); - parameters.add(input.getStatus()); - } - if (input.getApplicationVersion() != null) { - whereConditions.add("application_version = ?"); - parameters.add(input.getApplicationVersion()); - } - if (input.getWorkflowIDs() != null && !input.getWorkflowIDs().isEmpty()) { - // Handle IN clause: dynamically generate ? for each ID - StringJoiner inClausePlaceholders = new StringJoiner(", ", "(", ")"); - for (String id : input.getWorkflowIDs()) { - inClausePlaceholders.add("?"); - parameters.add(id); - } - whereConditions.add("workflow_uuid IN " + inClausePlaceholders.toString()); - } - if (input.getWorkflowIdPrefix() != null) { - whereConditions.add("workflow_uuid LIKE ?"); - // Append wildcard directly to the parameter value - parameters.add(input.getWorkflowIdPrefix() + "%"); - } + // --- WHERE Clauses --- + StringJoiner whereConditions = new StringJoiner(" AND "); - // Only append WHERE keyword if there are actual conditions - if (whereConditions.length() > 0) { - sqlBuilder.append(" WHERE ").append(whereConditions.toString()); - } + if (input.getWorkflowName() != null) { + whereConditions.add("name = ?"); + parameters.add(input.getWorkflowName()); + } + if (input.getAuthenticatedUser() != null) { + whereConditions.add("authenticated_user = ?"); + parameters.add(input.getAuthenticatedUser()); + } + if (input.getStartTime() != null) { + whereConditions.add("created_at >= ?"); + // Convert OffsetDateTime to epoch milliseconds for comparison with DB column + parameters.add(input.getStartTime().toInstant().toEpochMilli()); + } + if (input.getEndTime() != null) { + whereConditions.add("created_at <= ?"); + // Convert OffsetDateTime to epoch milliseconds for comparison with DB column + parameters.add(input.getEndTime().toInstant().toEpochMilli()); + } + if (input.getStatus() != null) { + whereConditions.add("status = ?"); + parameters.add(input.getStatus()); + } + if (input.getApplicationVersion() != null) { + whereConditions.add("application_version = ?"); + parameters.add(input.getApplicationVersion()); + } + if (input.getWorkflowIDs() != null && !input.getWorkflowIDs().isEmpty()) { + // Handle IN clause: dynamically generate ? for each ID + StringJoiner inClausePlaceholders = new StringJoiner(", ", "(", ")"); + for (String id : input.getWorkflowIDs()) { + inClausePlaceholders.add("?"); + parameters.add(id); + } + whereConditions.add("workflow_uuid IN " + inClausePlaceholders.toString()); + } + if (input.getWorkflowIdPrefix() != null) { + whereConditions.add("workflow_uuid LIKE ?"); + // Append wildcard directly to the parameter value + parameters.add(input.getWorkflowIdPrefix() + "%"); + } - // --- ORDER BY Clause --- - sqlBuilder.append(" ORDER BY created_at "); - if (input.getSortDesc() != null && input.getSortDesc()) { - sqlBuilder.append("DESC"); - } else { - sqlBuilder.append("ASC"); - } + // Only append WHERE keyword if there are actual conditions + if (whereConditions.length() > 0) { + sqlBuilder.append(" WHERE ").append(whereConditions.toString()); + } - // --- LIMIT and OFFSET Clauses --- - if (input.getLimit() != null) { - sqlBuilder.append(" LIMIT ?"); - parameters.add(input.getLimit()); - } - if (input.getOffset() != null) { - sqlBuilder.append(" OFFSET ?"); - parameters.add(input.getOffset()); - } + // --- ORDER BY Clause --- + sqlBuilder.append(" ORDER BY created_at "); + if (input.getSortDesc() != null && input.getSortDesc()) { + sqlBuilder.append("DESC"); + } else { + sqlBuilder.append("ASC"); + } - try (Connection connection = dataSource.getConnection(); - PreparedStatement pstmt = connection.prepareStatement(sqlBuilder.toString())) { - - for (int i = 0; i < parameters.size(); i++) { - - Object param = parameters.get(i); - if (param instanceof String) { - pstmt.setString(i + 1, (String) param); - } else if (param instanceof Long) { - pstmt.setLong(i + 1, (Long) param); - } else if (param instanceof Integer) { - pstmt.setInt(i + 1, (Integer) param); - } else { - // Fallback for other types, or if OffsetDateTime was directly added to parameters list - pstmt.setObject(i + 1, param); - } + // --- LIMIT and OFFSET Clauses --- + if (input.getLimit() != null) { + sqlBuilder.append(" LIMIT ?"); + parameters.add(input.getLimit()); } + if (input.getOffset() != null) { + sqlBuilder.append(" OFFSET ?"); + parameters.add(input.getOffset()); + } + + try (Connection connection = dataSource.getConnection(); + PreparedStatement pstmt = connection.prepareStatement(sqlBuilder.toString())) { - try (ResultSet rs = pstmt.executeQuery()) { - while (rs.next()) { - WorkflowStatus info = new WorkflowStatus(); - // The column names or their order in the SELECT statement must match. - info.setWorkflowId(rs.getString("workflow_uuid")); - info.setStatus(rs.getString("status")); - info.setName(rs.getString("name")); - info.setRecoveryAttempts(rs.getInt("recovery_attempts")); // getObject for nullable - info.setConfigName(rs.getString("config_name")); - info.setClassName(rs.getString("class_name")); - info.setAuthenticatedUser(rs.getString("authenticated_user")); - - String authenticatedRolesJson = rs.getString("authenticated_roles"); - if (authenticatedRolesJson != null) { - info.setAuthenticatedRoles(JSONUtil.fromJson(authenticatedRolesJson, new TypeReference>() {})); + for (int i = 0; i < parameters.size(); i++) { + + Object param = parameters.get(i); + if (param instanceof String) { + pstmt.setString(i + 1, (String) param); + } else if (param instanceof Long) { + pstmt.setLong(i + 1, (Long) param); + } else if (param instanceof Integer) { + pstmt.setInt(i + 1, (Integer) param); + } else { + // Fallback for other types, or if OffsetDateTime was directly added to parameters list + pstmt.setObject(i + 1, param); } + } - info.setAssumedRole(rs.getString("assumed_role")); - info.setQueueName(rs.getString("queue_name")); - info.setExecutorId(rs.getString("executor_id")); - info.setCreatedAt(rs.getObject("created_at", Long.class)); // getObject for nullable - info.setUpdatedAt(rs.getObject("updated_at", Long.class)); // getObject for nullable - info.setAppVersion(rs.getString("application_version")); - info.setAppId(rs.getString("application_id")); + try (ResultSet rs = pstmt.executeQuery()) { + while (rs.next()) { + WorkflowStatus info = new WorkflowStatus(); + // The column names or their order in the SELECT statement must match. + info.setWorkflowId(rs.getString("workflow_uuid")); + info.setStatus(rs.getString("status")); + info.setName(rs.getString("name")); + info.setRecoveryAttempts(rs.getInt("recovery_attempts")); // getObject for nullable + info.setConfigName(rs.getString("config_name")); + info.setClassName(rs.getString("class_name")); + info.setAuthenticatedUser(rs.getString("authenticated_user")); + + String authenticatedRolesJson = rs.getString("authenticated_roles"); + if (authenticatedRolesJson != null) { + info.setAuthenticatedRoles(JSONUtil.fromJson(authenticatedRolesJson, new TypeReference>() {})); + } - String serializedInput = rs.getString("inputs"); - String serializedOutput = rs.getString("output"); - String serializedError = rs.getString("error"); + info.setAssumedRole(rs.getString("assumed_role")); + info.setQueueName(rs.getString("queue_name")); + info.setExecutorId(rs.getString("executor_id")); + info.setCreatedAt(rs.getObject("created_at", Long.class)); // getObject for nullable + info.setUpdatedAt(rs.getObject("updated_at", Long.class)); // getObject for nullable + info.setAppVersion(rs.getString("application_version")); + info.setAppId(rs.getString("application_id")); - if (serializedInput != null) { - info.setInput((Object[])JSONUtil.deserialize((serializedInput)) ); - } + String serializedInput = rs.getString("inputs"); + String serializedOutput = rs.getString("output"); + String serializedError = rs.getString("error"); - if (serializedOutput != null) { - info.setOutput(JSONUtil.deserialize(serializedOutput)); - } + if (serializedInput != null) { + info.setInput((Object[])JSONUtil.deserialize((serializedInput)) ); + } - info.setError(serializedError); + if (serializedOutput != null) { + info.setOutput(JSONUtil.deserialize(serializedOutput)); + } - info.setWorkflowDeadlineEpochMs(rs.getObject("workflow_deadline_epoch_ms", Long.class)); - info.setWorkflowTimeoutMs(rs.getObject("workflow_timeout_ms", Long.class)); + info.setError(serializedError); - workflows.add(info); + info.setWorkflowDeadlineEpochMs(rs.getObject("workflow_deadline_epoch_ms", Long.class)); + info.setWorkflowTimeoutMs(rs.getObject("workflow_timeout_ms", Long.class)); + + workflows.add(info); + } } } + + + return workflows ; } + /** + * Helper method for tests + * Should be moved to TestUtils + */ + public void deleteWorkflowsTestHelper() throws SQLException{ - return workflows ; -} + String sql = "delete from dbos.workflow_status"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement pstmt = connection.prepareStatement(sql)) { + + int rowsAffected = pstmt.executeUpdate(); + logger.info("Cleaned up: Deleted " + rowsAffected + " rows from dbos.workflow_status"); + + } catch (SQLException e) { + logger.error("Error deleting workflows in test helper: " + e.getMessage()); + throw e; + } + + } -/** -* Helper method for tests - * Should be moved to TestUtils - */ -public void deleteWorkflowsTestHelper() throws SQLException{ + public void deleteOperations() throws SQLException{ - String sql = "delete from dbos.workflow_status"; + String sql = "delete from dbos.operation_outputs;"; - try (Connection connection = dataSource.getConnection(); - PreparedStatement pstmt = connection.prepareStatement(sql)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement pstmt = connection.prepareStatement(sql)) { + + int rowsAffected = pstmt.executeUpdate(); + logger.info("Cleaned up: Deleted " + rowsAffected + " rows from dbos.operation_outputs"); - int rowsAffected = pstmt.executeUpdate(); - logger.info("Cleaned up: Deleted " + rowsAffected + " rows from dbos.workflow_status"); + } catch (SQLException e) { + logger.error("Error deleting workflows in test helper: " + e.getMessage()); + throw e; + } - } catch (SQLException e) { - logger.error("Error deleting workflows in test helper: " + e.getMessage()); - throw e; } -} + public StepResult checkStepExecutionTxn( + String workflowId, + int functionId, + String functionName + ) throws IllegalStateException, WorkflowCancelledException, UnExpectedStepException { -public Object awaitWorkflowResult(String workflowId) throws Exception { + try { + try (Connection connection = dataSource.getConnection()) { + return stepsDAO.checkStepExecutionTxn(workflowId, functionId, functionName, connection); + } + } catch(SQLException sq) { + logger.error("Unexpected SQL exception", sq) ; + throw new DBOSException(UNEXPECTED.getCode(), sq.getMessage()) ; + } + } - final String sql = "SELECT status, output, error "+ - "FROM dbos.workflow_status " + - "WHERE workflow_uuid = ?" ; + public void recordStepResultTxn(StepResult result) { - while (true) { + try { + stepsDAO.recordStepResultTxn(result); + } catch(SQLException sq) { + logger.error("Unexpected SQL exception", sq) ; + throw new DBOSException(UNEXPECTED.getCode(), sq.getMessage()) ; + } - try (Connection connection = dataSource.getConnection(); - PreparedStatement stmt = connection.prepareStatement(sql)) { + } - stmt.setString(1, workflowId); + public List listWorkflowSteps(String workflowId) { + try { + return stepsDAO.listWorkflowSteps(workflowId); + } catch(SQLException sq) { + logger.error("Unexpected SQL exception", sq) ; + throw new DBOSException(UNEXPECTED.getCode(), sq.getMessage()) ; + } - try (ResultSet rs = stmt.executeQuery()) { - if (rs.next()) { - String status = rs.getString("status"); + } + + public Object awaitWorkflowResult(String workflowId) throws Exception { + + final String sql = "SELECT status, output, error " + + "FROM dbos.workflow_status " + + "WHERE workflow_uuid = ?"; + + while (true) { + + try (Connection connection = dataSource.getConnection(); + PreparedStatement stmt = connection.prepareStatement(sql)) { + + stmt.setString(1, workflowId); - switch (WorkflowState.valueOf(status.toUpperCase())) { - case SUCCESS: - String output = rs.getString("output"); - return output != null ? JSONUtil.deserialize(output) : null; - case ERROR: - String error = rs.getString("error"); - // TODO fixException exception = serialization.deserializeException(error); - throw new Exception(error); - - case CANCELLED: - throw new AwaitedWorkflowCancelledException(workflowId); - - default: - // Status is PENDING or other - continue polling - break; + try (ResultSet rs = stmt.executeQuery()) { + if (rs.next()) { + String status = rs.getString("status"); + + switch (WorkflowState.valueOf(status.toUpperCase())) { + case SUCCESS: + String output = rs.getString("output"); + return output != null ? JSONUtil.deserialize(output) : null; + case ERROR: + String error = rs.getString("error"); + // TODO fixException exception = serialization.deserializeException(error); + throw new Exception(error); + + case CANCELLED: + throw new AwaitedWorkflowCancelledException(workflowId); + + default: + // Status is PENDING or other - continue polling + break; } } // Row not found - workflow hasn't appeared yet, continue polling @@ -663,9 +721,10 @@ public Object awaitWorkflowResult(String workflowId) throws Exception { } + } -private void createDataSource(String dbName) { + private void createDataSource(String dbName) { HikariConfig hikariConfig = new HikariConfig(); String dburl = String.format("jdbc:postgresql://%s:%d/%s",config.getDbHost(),config.getDbPort(),dbName); diff --git a/src/main/java/dev/dbos/transact/exceptions/ErrorCode.java b/src/main/java/dev/dbos/transact/exceptions/ErrorCode.java index 9d6e82be..dc5a5c29 100644 --- a/src/main/java/dev/dbos/transact/exceptions/ErrorCode.java +++ b/src/main/java/dev/dbos/transact/exceptions/ErrorCode.java @@ -7,7 +7,9 @@ public enum ErrorCode { QUEUE_DUPLICATED(3) , DEAD_LETTER_QUEUE(4) , NONEXISTENT_WORKFLOW(5) , - AWAITED_WORKFLOW_CANCEL(6); + AWAITED_WORKFLOW_CANCEL(6), + WORKFLOW_CANCELLED(7), + UNEXPECTED_STEP(8); private int code ; diff --git a/src/main/java/dev/dbos/transact/exceptions/UnExpectedStepException.java b/src/main/java/dev/dbos/transact/exceptions/UnExpectedStepException.java new file mode 100644 index 00000000..c7ac480c --- /dev/null +++ b/src/main/java/dev/dbos/transact/exceptions/UnExpectedStepException.java @@ -0,0 +1,20 @@ +package dev.dbos.transact.exceptions; + +import static dev.dbos.transact.exceptions.ErrorCode.UNEXPECTED_STEP; + +public class UnExpectedStepException extends DBOSException { + private final String workflowId; + private final int stepId; + private final String expectedName; + private final String recordedName; + + public UnExpectedStepException(String workflowId, int stepId, String expectedName, String recordedName) { + super(UNEXPECTED_STEP.getCode(), + String.format("During execution of workflow %s step %s, function %s was recorded when %s was expected. Check that your workflow is deterministic.", + workflowId, stepId, recordedName, expectedName)); + this.workflowId = workflowId; + this.stepId = stepId ; + this.expectedName = expectedName; + this.recordedName = recordedName; + } +} diff --git a/src/main/java/dev/dbos/transact/exceptions/WorkflowCancelledException.java b/src/main/java/dev/dbos/transact/exceptions/WorkflowCancelledException.java new file mode 100644 index 00000000..833c6d39 --- /dev/null +++ b/src/main/java/dev/dbos/transact/exceptions/WorkflowCancelledException.java @@ -0,0 +1,8 @@ +package dev.dbos.transact.exceptions; + +public class WorkflowCancelledException extends DBOSException { + public WorkflowCancelledException(String workflowId) { + super(ErrorCode.WORKFLOW_CANCELLED.getCode(), + String.format("Workflow %s has been cancelled", workflowId)); + } +} diff --git a/src/main/java/dev/dbos/transact/execution/ContextAwareCallable.java b/src/main/java/dev/dbos/transact/execution/ContextAwareCallable.java index 554c1991..eceeaf3f 100644 --- a/src/main/java/dev/dbos/transact/execution/ContextAwareCallable.java +++ b/src/main/java/dev/dbos/transact/execution/ContextAwareCallable.java @@ -2,18 +2,23 @@ import dev.dbos.transact.context.DBOSContext; import dev.dbos.transact.context.DBOSContextHolder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.concurrent.Callable; public class ContextAwareCallable implements Callable { private final Callable task; - private final DBOSContext capturedContext; + private DBOSContext capturedContext; - public ContextAwareCallable(Callable task) { + Logger logger = LoggerFactory.getLogger(ContextAwareCallable.class) ; + + public ContextAwareCallable(DBOSContext ctx, Callable task) { this.task = task; - this.capturedContext = DBOSContextHolder.get(); + this.capturedContext = ctx; } + @Override public T call() throws Exception { DBOSContextHolder.set(capturedContext); diff --git a/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java b/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java index 92f36623..32f1a438 100644 --- a/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java +++ b/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java @@ -9,6 +9,7 @@ import dev.dbos.transact.workflow.WorkflowHandle; import dev.dbos.transact.workflow.WorkflowState; import dev.dbos.transact.workflow.WorkflowStatus; +import dev.dbos.transact.workflow.internal.StepResult; import dev.dbos.transact.workflow.internal.WorkflowHandleDBPoll; import dev.dbos.transact.workflow.internal.WorkflowHandleFuture; import dev.dbos.transact.workflow.internal.WorkflowStatusInternal; @@ -120,6 +121,9 @@ public T runWorkflow(String workflowName, if (wfid == null) { wfid = UUID.randomUUID().toString(); + ctx.setWorkflowId(wfid); + } else { + logger.info("workflowId from context ", wfid); } } @@ -137,7 +141,7 @@ public T runWorkflow(String workflowName, logger.warn("Idempotency check not impl for cancelled"); } - logger.info("Before executing workflow") ; + logger.info("Before executing workflow " + DBOSContextHolder.get().getWorkflowId()) ; T result = function.execute(); // invoke the lambda logger.info("After: Workflow completed successfully"); postInvokeWorkflow(initResult.getWorkflowId(), result); @@ -162,13 +166,19 @@ public WorkflowHandle submitWorkflow(String workflowName, if (workflowId == null) { workflowId = UUID.randomUUID().toString(); + ctx.setWorkflowId(workflowId); } final String wfId = workflowId ; Callable task = () -> { T result = null ; - logger.info("Callable executing the workflow.. " + wfId); + + // Doing this on purpose to ensure that we have the correct context + String id = DBOSContextHolder.get().getWorkflowId(); + + logger.info("Callable executing the workflow.. " + id); + try { result = runWorkflow(workflowName, @@ -176,7 +186,8 @@ public WorkflowHandle submitWorkflow(String workflowName, methodName, args, function, - wfId); + // wfId); doing it the hard way + id); } catch (Throwable e) { @@ -191,12 +202,93 @@ public WorkflowHandle submitWorkflow(String workflowName, return result ; }; - Future future = executorService.submit(task); + // Copy the context - dont just pass a reference - memory visibility + ContextAwareCallable contextAwareTask = new ContextAwareCallable<>(DBOSContextHolder.get().copy(),task); + Future future = executorService.submit(contextAwareTask); return new WorkflowHandleFuture(workflowId, future, systemDatabase); } + public T runStep(String stepName, + boolean retriedAllowed, + int maxAttempts, + float backOffRate, + Object[] args, + DBOSFunction function + ) throws Throwable { + + + DBOSContext ctx = DBOSContextHolder.get(); + String workflowId = ctx.getWorkflowId(); + + if (workflowId == null) { + throw new DBOSException(UNEXPECTED.getCode(), "No workflow id. Step must be called from workflow"); + } + logger.info(String.format("Running step %s for workflow %s", stepName, workflowId)) ; + + int stepFunctionId = ctx.getAndIncrementFunctionId() ; + + StepResult recordedResult = systemDatabase.checkStepExecutionTxn(workflowId, stepFunctionId, stepName) ; + + if (recordedResult != null) { + + String output = recordedResult.getOutput() ; + if (output != null) { + return (T) JSONUtil.deserialize(output) ; + } + + String error = recordedResult.getError(); + if (error != null) { + // TODO: fix deserialization of errors + throw new Exception(error); + } + } + + int currAttempts = 1 ; + String serializedOutput = null ; + Throwable eThrown = null ; + T result = null ; + + while (retriedAllowed && currAttempts <= maxAttempts) { + + try { + logger.info("Before executing step"); + result = function.execute(); + logger.info("After: step completed successfully " + result);// invoke the lambda + serializedOutput = JSONUtil.serialize(result); + logger.info("Json serialized output is " + serializedOutput); + eThrown = null ; + } catch(Exception e) { + // TODO: serialize + Throwable actual = (e instanceof InvocationTargetException) + ? ((InvocationTargetException) e).getTargetException() + : e; + logger.info("After: step threw exception " + actual.getMessage() + "-----" + actual.toString()) ; + eThrown = actual; + } + + ++currAttempts; + } + + if (eThrown == null) { + StepResult stepResult = new StepResult(workflowId, stepFunctionId, stepName, serializedOutput, null); + systemDatabase.recordStepResultTxn(stepResult); + return result; + } else { + // TODO: serialize + logger.info("After: step threw exception saving error " + eThrown.getMessage()) ; + StepResult stepResult = new StepResult(workflowId, stepFunctionId, stepName, null, eThrown.getMessage()); + systemDatabase.recordStepResultTxn(stepResult); + throw eThrown; + } + } + + + /** + * Retrieve the workflowHandle for the workflowId + * + */ public WorkflowHandle retrieveWorkflow(String workflowId) { return new WorkflowHandleDBPoll(workflowId, systemDatabase) ; } diff --git a/src/main/java/dev/dbos/transact/interceptor/AsyncInvocationHandler.java b/src/main/java/dev/dbos/transact/interceptor/AsyncInvocationHandler.java index 8a8580ac..5aef69de 100644 --- a/src/main/java/dev/dbos/transact/interceptor/AsyncInvocationHandler.java +++ b/src/main/java/dev/dbos/transact/interceptor/AsyncInvocationHandler.java @@ -1,6 +1,8 @@ package dev.dbos.transact.interceptor; import dev.dbos.transact.execution.DBOSExecutor; +import dev.dbos.transact.workflow.Step; +import dev.dbos.transact.workflow.Transaction; import dev.dbos.transact.workflow.Workflow; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,32 +49,58 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl Method implMethod = target.getClass().getMethod(method.getName(), method.getParameterTypes()); - Workflow wfAnnotation = implMethod.getAnnotation(Workflow.class) ; + if (implMethod.isAnnotationPresent(Workflow.class)) { + return handleWorkflow(method, args, implMethod.getAnnotation(Workflow.class)); - if (wfAnnotation != null) { + } else if (implMethod.isAnnotationPresent(Step.class)) { + return handleStep(method, args, implMethod.getAnnotation(Step.class)); + } - String workflowName = wfAnnotation.name().isEmpty() ? implMethod.getName() : wfAnnotation.name(); + // No special annotation, proceed normally + return method.invoke(target, args); - String msg = String.format("Before: Starting workflow '%s' (timeout: %ds)%n", - workflowName, - wfAnnotation.timeout()); - logger.info(msg); + } - dbosExecutor.submitWorkflow( - workflowName, - targetClassName, - method.getName(), - args, - () -> (Object) method.invoke(target, args) - ); + protected Object handleWorkflow(Method method, Object[] args, Workflow workflow) throws Throwable { - return getDefaultValue(method.getReturnType()) ; // always return null or default + String workflowName = workflow.name().isEmpty() ? method.getName() : workflow.name(); - } else { - throw new RuntimeException("workflow annotation expected on target method"); - } + String msg = String.format("Before: Starting workflow '%s' (timeout: %ds)%n", + workflowName, + workflow.timeout()); + + logger.info(msg); + + dbosExecutor.submitWorkflow( + workflowName, + targetClassName, + method.getName(), + args, + () -> (Object) method.invoke(target, args) + ); + return getDefaultValue(method.getReturnType()) ; // always return null or default + + } + + protected Object handleStep(Method method, Object[] args, Step step) throws Throwable { + String msg = String.format("Before : Executing step %s %s", + method.getName(), step.name()); + logger.info(msg); + try { + Object result = dbosExecutor.runStep(step.name(), + step.retriesAllowed(), + step.maxAttempts(), + step.backOffRate(), + args, + ()-> method.invoke(target, args)) ; + logger.info("After: Step completed successfully"); + return result; + } catch (Exception e) { + logger.info("Step failed: " + e.getCause().getMessage()); + throw e.getCause(); + } } private Object getDefaultValue(Class returnType) { diff --git a/src/main/java/dev/dbos/transact/interceptor/TransactInvocationHandler.java b/src/main/java/dev/dbos/transact/interceptor/TransactInvocationHandler.java index 4efb76b5..a076205b 100644 --- a/src/main/java/dev/dbos/transact/interceptor/TransactInvocationHandler.java +++ b/src/main/java/dev/dbos/transact/interceptor/TransactInvocationHandler.java @@ -104,13 +104,15 @@ protected Object handleStep(Method method, Object[] args, Step step) throws Thro String msg = String.format("Before : Executing step %s %s", method.getName(), step.name()); logger.info(msg); - try { - Object result = method.invoke(target, args); + + Object result = dbosExecutor.runStep(step.name(), + step.retriesAllowed(), + step.maxAttempts(), + step.backOffRate(), + args, + ()-> method.invoke(target, args)) ; logger.info("After: Step completed successfully"); return result; - } catch (Exception e) { - logger.info("Step failed: " + e.getCause().getMessage()); - throw e.getCause(); - } + } } diff --git a/src/main/java/dev/dbos/transact/workflow/Step.java b/src/main/java/dev/dbos/transact/workflow/Step.java index 9f4d0868..a6b3f524 100644 --- a/src/main/java/dev/dbos/transact/workflow/Step.java +++ b/src/main/java/dev/dbos/transact/workflow/Step.java @@ -9,4 +9,7 @@ @Target(ElementType.METHOD) public @interface Step { String name() default ""; + boolean retriesAllowed() default true; + int maxAttempts() default 1; + float backOffRate() default 1.0f; } \ No newline at end of file diff --git a/src/main/java/dev/dbos/transact/workflow/StepInfo.java b/src/main/java/dev/dbos/transact/workflow/StepInfo.java new file mode 100644 index 00000000..2043e0f9 --- /dev/null +++ b/src/main/java/dev/dbos/transact/workflow/StepInfo.java @@ -0,0 +1,72 @@ +package dev.dbos.transact.workflow; + +public class StepInfo { + private final int functionId; + private final String functionName; + private final Object output; + private final Exception error; + private final String childWorkflowId; + + public StepInfo(int functionId, String functionName, Object output, Exception error, String childWorkflowId) { + this.functionId = functionId; + this.functionName = functionName; + this.output = output; + this.error = error; + this.childWorkflowId = childWorkflowId; + } + + public int getFunctionId() { + return functionId; + } + + public String getFunctionName() { + return functionName; + } + + public Object getOutput() { + return output; + } + + public Exception getError() { + return error; + } + + public String getChildWorkflowId() { + return childWorkflowId; + } + + @Override + public String toString() { + return "StepInfo{" + + "functionId=" + functionId + + ", functionName='" + functionName + '\'' + + ", output=" + output + + ", error=" + error + + ", childWorkflowId='" + childWorkflowId + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StepInfo stepInfo = (StepInfo) o; + + if (functionId != stepInfo.functionId) return false; + if (!functionName.equals(stepInfo.functionName)) return false; + if (output != null ? !output.equals(stepInfo.output) : stepInfo.output != null) return false; + if (error != null ? !error.equals(stepInfo.error) : stepInfo.error != null) return false; + return childWorkflowId != null ? childWorkflowId.equals(stepInfo.childWorkflowId) : stepInfo.childWorkflowId == null; + } + + @Override + public int hashCode() { + int result = functionId; + result = 31 * result + functionName.hashCode(); + result = 31 * result + (output != null ? output.hashCode() : 0); + result = 31 * result + (error != null ? error.hashCode() : 0); + result = 31 * result + (childWorkflowId != null ? childWorkflowId.hashCode() : 0); + return result; + } +} diff --git a/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java b/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java new file mode 100644 index 00000000..4b94c2c5 --- /dev/null +++ b/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java @@ -0,0 +1,31 @@ +package dev.dbos.transact.workflow.internal; + +public class StepResult { + private String workflowId; + private int functionId; + private String functionName; + private String output; + private String error; + + public StepResult() {} + + public StepResult(String workflowId, int functionID, String functionName, String output, String error) { + this.workflowId = workflowId; + this.functionId = functionID; + this.functionName = functionName; + this.output = output; + this.error = error; + } + + public String getWorkflowId() { return workflowId; } + public int getFunctionId() { return functionId; } + public String getFunctionName() { return functionName; } + public String getOutput() { return output; } + public String getError() { return error; } + + public void setWorkflowUUID(String workflowUUID) { this.workflowId = workflowUUID; } + public void setFunctionID(int functionID) { this.functionId = functionID; } + public void setFunctionName(String functionName) { this.functionName = functionName; } + public void setOutput(String output) { this.output = output; } + public void setError(String error) { this.error = error; } +} diff --git a/src/test/java/dev/dbos/transact/interceptor/OrderService.java b/src/test/java/dev/dbos/transact/interceptor/OrderService.java deleted file mode 100644 index 0091ef83..00000000 --- a/src/test/java/dev/dbos/transact/interceptor/OrderService.java +++ /dev/null @@ -1,10 +0,0 @@ -package dev.dbos.transact.interceptor; - -public interface OrderService { - - String processOrder(String item) ; - - String reserveInventory(String orderId, int itemId, int quantity); - - String chargeCustomer(String orderId, double amount); -} diff --git a/src/test/java/dev/dbos/transact/interceptor/OrderServiceImpl.java b/src/test/java/dev/dbos/transact/interceptor/OrderServiceImpl.java deleted file mode 100644 index 51befd18..00000000 --- a/src/test/java/dev/dbos/transact/interceptor/OrderServiceImpl.java +++ /dev/null @@ -1,27 +0,0 @@ -package dev.dbos.transact.interceptor; - -import dev.dbos.transact.workflow.Step; -import dev.dbos.transact.workflow.Transaction; -import dev.dbos.transact.workflow.Workflow; - -public class OrderServiceImpl implements OrderService { - - @Override - @Workflow(name = "processOrder") - public String processOrder(String item) { - return "Processed: " + item; - } - - @Override - @Step(name = "reserve") - public String reserveInventory(String orderId, int itemId, int quantity) { - return orderId + itemId + quantity ; - } - - @Override - @Transaction(name = "charge") - public String chargeCustomer(String orderId, double amount) { - return orderId+amount ; - - } -} diff --git a/src/test/java/dev/dbos/transact/interceptor/TransactInvocationHandlerTest.java b/src/test/java/dev/dbos/transact/interceptor/TransactInvocationHandlerTest.java deleted file mode 100644 index 95ecf54a..00000000 --- a/src/test/java/dev/dbos/transact/interceptor/TransactInvocationHandlerTest.java +++ /dev/null @@ -1,112 +0,0 @@ -package dev.dbos.transact.interceptor; - -import dev.dbos.transact.DBOS; -import dev.dbos.transact.config.DBOSConfig; -import dev.dbos.transact.database.SystemDatabase; -import dev.dbos.transact.execution.DBOSExecutor; -import dev.dbos.transact.workflow.Workflow; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; - -import java.lang.reflect.Method; -import java.lang.reflect.Proxy; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.*; - -class TransactInvocationHandlerTest { - - // Once we write more workflow tests, this test can be removed - - @Test - void invokeWorkflow() throws Throwable { - - OrderServiceImpl impl = new OrderServiceImpl(); - - DBOSExecutor executor = mock(DBOSExecutor.class) ; - - doReturn("Processed: test-item").when(executor).runWorkflow(anyString(), anyString(), anyString(), any(Object[].class), any(),any()); - - TransactInvocationHandler realHandler = - new TransactInvocationHandler(impl, executor); - - TransactInvocationHandler spyHandler = Mockito.spy(realHandler); - - OrderService proxy = (OrderService) Proxy.newProxyInstance( - OrderService.class.getClassLoader(), - new Class[]{OrderService.class}, - spyHandler - ); - - String result = proxy.processOrder("test-item"); - assertEquals("Processed: test-item", result); - - // Assert - Method expectedMethod = OrderService.class.getMethod("processOrder", String.class); - verify(spyHandler, times(1)).handleWorkflow(eq(expectedMethod), any(),any()); - - } - - @Test - void invokeStep() throws Throwable { - - OrderServiceImpl impl = new OrderServiceImpl(); - - DBOSExecutor executor = mock(DBOSExecutor.class) ; - doReturn(new SystemDatabase.WorkflowInitResult("121","PENDING",123L)).when(executor).preInvokeWorkflow(anyString(), - anyString(), anyString(), anyString(), any(Object[].class), any()); - doNothing().when(executor).postInvokeWorkflow(anyString(), any()); - - TransactInvocationHandler realHandler = - new TransactInvocationHandler(impl, executor); - - TransactInvocationHandler spyHandler = Mockito.spy(realHandler); - - OrderService proxy = (OrderService) Proxy.newProxyInstance( - OrderService.class.getClassLoader(), - new Class[]{OrderService.class}, - spyHandler - ); - - String result = proxy.reserveInventory("123",21, 1); - assertEquals("123211",result); - - // Assert - Method expectedMethod = OrderService.class.getMethod("reserveInventory", String.class, int.class, int.class); - verify(spyHandler, times(1)).handleStep(eq(expectedMethod), any(),any()); - - } - - @Test - void invokeTransaction() throws Throwable { - - OrderServiceImpl impl = new OrderServiceImpl(); - - DBOSExecutor executor = mock(DBOSExecutor.class) ; - doReturn(new SystemDatabase.WorkflowInitResult("121","PENDING",123L)).when(executor) - .preInvokeWorkflow(anyString(), anyString(), anyString(), anyString(), any(Object[].class),any()); - doNothing().when(executor).postInvokeWorkflow(anyString(), any()); - - TransactInvocationHandler realHandler = - new TransactInvocationHandler(impl , executor); - - TransactInvocationHandler spyHandler = Mockito.spy(realHandler); - - OrderService proxy = (OrderService) Proxy.newProxyInstance( - OrderService.class.getClassLoader(), - new Class[]{OrderService.class}, - spyHandler - ); - - String result = proxy.chargeCustomer("123",45.23); - assertEquals("12345.23",result); - - // Assert - Method expectedMethod = OrderService.class.getMethod("chargeCustomer", String.class, double.class); - verify(spyHandler, times(1)).handleTransaction(eq(expectedMethod), any(),any()); - - } -} - diff --git a/src/test/java/dev/dbos/transact/step/ServiceA.java b/src/test/java/dev/dbos/transact/step/ServiceA.java new file mode 100644 index 00000000..295bc828 --- /dev/null +++ b/src/test/java/dev/dbos/transact/step/ServiceA.java @@ -0,0 +1,10 @@ +package dev.dbos.transact.step; + +public interface ServiceA { + + public void setServiceB(ServiceB serviceB) ; + + public String workflowWithSteps(String input) ; + public String workflowWithStepError(String input); + +} diff --git a/src/test/java/dev/dbos/transact/step/ServiceAImpl.java b/src/test/java/dev/dbos/transact/step/ServiceAImpl.java new file mode 100644 index 00000000..4054a539 --- /dev/null +++ b/src/test/java/dev/dbos/transact/step/ServiceAImpl.java @@ -0,0 +1,56 @@ +package dev.dbos.transact.step; + +import dev.dbos.transact.context.DBOSContext; +import dev.dbos.transact.context.DBOSContextHolder; +import dev.dbos.transact.workflow.Workflow; + +public class ServiceAImpl implements ServiceA { + + private ServiceB serviceBproxy ; + + ServiceAImpl(ServiceB b) { + this.serviceBproxy = b ; + } + + public void setServiceB(ServiceB serviceB) { + serviceBproxy = serviceB ; + } + + @Workflow(name = "workflowWithSteps") + public String workflowWithSteps(String input) { + + DBOSContext ctx = DBOSContextHolder.get(); + String wfid = ctx.getWorkflowId() ; + + + serviceBproxy.step1("one"); + serviceBproxy.step2("two"); + try { + serviceBproxy.step3("three",false); + } catch(Exception e) { + // Nothing to do + System.out.println(e.getMessage()) ; + } + serviceBproxy.step4("four"); + serviceBproxy.step5("five"); + + return input+input; + } + + @Workflow(name = "workflowWithSteps") + public String workflowWithStepError(String input) { + + serviceBproxy.step1("one"); + serviceBproxy.step2("two"); + try { + serviceBproxy.step3("three",true); + } catch(Exception e) { + // Nothing to do + System.out.println(e.getMessage()) ; + } + serviceBproxy.step4("four"); + serviceBproxy.step5("five"); + + return input+input; + } +} diff --git a/src/test/java/dev/dbos/transact/step/ServiceB.java b/src/test/java/dev/dbos/transact/step/ServiceB.java new file mode 100644 index 00000000..c09beee8 --- /dev/null +++ b/src/test/java/dev/dbos/transact/step/ServiceB.java @@ -0,0 +1,11 @@ +package dev.dbos.transact.step; + +public interface ServiceB { + + public String step1(String input) ; + public String step2(String input) ; + public String step3(String input, boolean throwError) throws Exception; + public String step4(String input) ; + public String step5(String input) ; + +} diff --git a/src/test/java/dev/dbos/transact/step/ServiceBImpl.java b/src/test/java/dev/dbos/transact/step/ServiceBImpl.java new file mode 100644 index 00000000..3a684abb --- /dev/null +++ b/src/test/java/dev/dbos/transact/step/ServiceBImpl.java @@ -0,0 +1,36 @@ +package dev.dbos.transact.step; + +import dev.dbos.transact.workflow.Step; + +public class ServiceBImpl implements ServiceB { + + @Step(name = "step1") + public String step1(String input) { + return input; + } + + @Step(name = "step2") + public String step2(String input) { + return input; + } + + @Step(name = "step3") + public String step3(String input, boolean throwError) throws Exception { + if (throwError) { + throw new Exception("step3 error"); + } + + return input; + + } + + @Step(name = "step4") + public String step4(String input) { + return input ; + } + + @Step(name = "step5") + public String step5(String input) { + return input ; + } +} diff --git a/src/test/java/dev/dbos/transact/step/ServiceWFAndStep.java b/src/test/java/dev/dbos/transact/step/ServiceWFAndStep.java new file mode 100644 index 00000000..36d0c7a5 --- /dev/null +++ b/src/test/java/dev/dbos/transact/step/ServiceWFAndStep.java @@ -0,0 +1,10 @@ +package dev.dbos.transact.step; + +public interface ServiceWFAndStep { + + void setSelf(ServiceWFAndStep serviceWFAndStep) ; + + String aWorkflow(String input) ; + String stepOne(String input) ; + String stepTwo(String input) ; +} diff --git a/src/test/java/dev/dbos/transact/step/ServiceWFAndStepImpl.java b/src/test/java/dev/dbos/transact/step/ServiceWFAndStepImpl.java new file mode 100644 index 00000000..4200c114 --- /dev/null +++ b/src/test/java/dev/dbos/transact/step/ServiceWFAndStepImpl.java @@ -0,0 +1,31 @@ +package dev.dbos.transact.step; + +import dev.dbos.transact.workflow.Step; +import dev.dbos.transact.workflow.Workflow; + +public class ServiceWFAndStepImpl implements ServiceWFAndStep { + + private ServiceWFAndStep self ; + + public void setSelf(ServiceWFAndStep serviceWFAndStep) { + self = serviceWFAndStep; + } + + @Workflow(name = "myworkflow") + public String aWorkflow(String input) { + + String s1 = self.stepOne("one"); + String s2 = self.stepTwo("two"); + return input+s1+s2 ; + } + + @Step(name = "step1") + public String stepOne(String input) { + return input ; + } + + @Step(name = "step2") + public String stepTwo(String input) { + return input; + } +} diff --git a/src/test/java/dev/dbos/transact/step/StepsTest.java b/src/test/java/dev/dbos/transact/step/StepsTest.java new file mode 100644 index 00000000..39c26e6f --- /dev/null +++ b/src/test/java/dev/dbos/transact/step/StepsTest.java @@ -0,0 +1,228 @@ +package dev.dbos.transact.step; + +import dev.dbos.transact.DBOS; +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.context.SetWorkflowID; +import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.execution.DBOSExecutor; +import dev.dbos.transact.workflow.*; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +public class StepsTest { + + private static DBOSConfig dbosConfig; + private static DBOS dbos ; + private static SystemDatabase systemDatabase ; + private static DBOSExecutor dbosExecutor; + + @BeforeAll + static void onetimeSetup() throws Exception { + + StepsTest.dbosConfig = new DBOSConfig + .Builder() + .name("systemdbtest") + .dbHost("localhost") + .dbPort(5432) + .dbUser("postgres") + .sysDbName("dbos_java_sys") + .maximumPoolSize(2) + .build(); + + String dbUrl = String.format("jdbc:postgresql://%s:%d/%s", dbosConfig.getDbHost(), dbosConfig.getDbPort(), "postgres"); + + String sysDb = dbosConfig.getSysDbName(); + try (Connection conn = DriverManager.getConnection(dbUrl, dbosConfig.getDbUser(), dbosConfig.getDbPassword()); + Statement stmt = conn.createStatement()) { + + + String dropDbSql = String.format("DROP DATABASE IF EXISTS %s", sysDb); + String createDbSql = String.format("CREATE DATABASE %s", sysDb); + stmt.execute(dropDbSql); + stmt.execute(createDbSql); + } + + DBOS.initialize(dbosConfig); + dbos = DBOS.getInstance(); + SystemDatabase.initialize(dbosConfig); + systemDatabase = SystemDatabase.getInstance(); + dbosExecutor = new DBOSExecutor(dbosConfig, systemDatabase); + dbos.setDbosExecutor(dbosExecutor); + dbos.launch(); + + } + + @AfterAll + static void onetimeTearDown() { + dbos.shutdown(); + } + + @BeforeEach + void beforeEachTest() throws SQLException { + systemDatabase.deleteOperations(); + systemDatabase.deleteWorkflowsTestHelper(); + } + + + @Test + public void workflowWithStepsSync() throws SQLException { + + ServiceB serviceB = dbos.Workflow() + .interfaceClass(ServiceB.class) + .implementation(new ServiceBImpl()) + .build(); + + + ServiceA serviceA = dbos.Workflow() + .interfaceClass(ServiceA.class) + .implementation(new ServiceAImpl(serviceB)) + .build(); + + String wid = "sync123"; + + try (SetWorkflowID id = new SetWorkflowID(wid)) { + String result = serviceA.workflowWithSteps("hello"); + assertEquals("hellohello", result); + } + + List stepInfos = systemDatabase.listWorkflowSteps(wid); + assertEquals(5, stepInfos.size()); + + assertEquals("step1",stepInfos.get(0).getFunctionName()); + assertEquals(0,stepInfos.get(0).getFunctionId()); + assertEquals("one",stepInfos.get(0).getOutput()); + assertNull(stepInfos.get(0).getError()); + assertEquals("step2",stepInfos.get(1).getFunctionName()); + assertEquals(1,stepInfos.get(1).getFunctionId()); + assertEquals("two",stepInfos.get(1).getOutput()); + assertEquals("step3",stepInfos.get(2).getFunctionName()); + assertEquals(2,stepInfos.get(2).getFunctionId()); + assertEquals("three",stepInfos.get(2).getOutput()); + assertEquals("step4",stepInfos.get(3).getFunctionName()); + assertEquals(3,stepInfos.get(3).getFunctionId()); + assertEquals("four",stepInfos.get(3).getOutput()); + assertEquals("step5",stepInfos.get(4).getFunctionName()); + assertEquals(4,stepInfos.get(4).getFunctionId()); + assertEquals("five",stepInfos.get(4).getOutput()); + + } + + @Test + public void workflowWithStepsSyncError() throws SQLException { + + ServiceB serviceB = dbos.Workflow() + .interfaceClass(ServiceB.class) + .implementation(new ServiceBImpl()) + .build(); + + ServiceA serviceA = dbos.Workflow() + .interfaceClass(ServiceA.class) + .implementation(new ServiceAImpl(serviceB)) + .build(); + + String wid = "sync123er"; + try (SetWorkflowID id = new SetWorkflowID(wid)) { + String result = serviceA.workflowWithStepError("hello"); + assertEquals("hellohello", result); + } + + List stepInfos = systemDatabase.listWorkflowSteps(wid); + assertEquals(5, stepInfos.size()); + assertEquals("step3",stepInfos.get(2).getFunctionName()); + assertEquals(2,stepInfos.get(2).getFunctionId()); + Throwable error = stepInfos.get(2).getError(); + assertInstanceOf(Exception.class, error, "The error should be an Exception"); + assertEquals("step3 error", error.getMessage(), "Error message should match"); + assertNull(stepInfos.get(2).getOutput()) ; + } + + @Test + public void AsyncworkflowWithSteps() throws Exception { + + ServiceB serviceB = dbos.Workflow() + .interfaceClass(ServiceB.class) + .implementation(new ServiceBImpl()) + .build(); + + + ServiceA serviceA = dbos.Workflow() + .interfaceClass(ServiceA.class) + .implementation(new ServiceAImpl(serviceB)) + .async() + .build(); + + String workflowId = "wf-1234"; + + try (SetWorkflowID id = new SetWorkflowID(workflowId)) { + serviceA.workflowWithSteps("hello"); + } + + WorkflowHandle handle = dbosExecutor.retrieveWorkflow(workflowId); + assertEquals("hellohello", handle.getResult()); + + List stepInfos = systemDatabase.listWorkflowSteps(workflowId); + assertEquals(5, stepInfos.size()); + + assertEquals("step1",stepInfos.get(0).getFunctionName()); + assertEquals(0,stepInfos.get(0).getFunctionId()); + assertEquals("one",stepInfos.get(0).getOutput()); + assertEquals("step2",stepInfos.get(1).getFunctionName()); + assertEquals(1,stepInfos.get(1).getFunctionId()); + assertEquals("two",stepInfos.get(1).getOutput()); + assertEquals("step3",stepInfos.get(2).getFunctionName()); + assertEquals(2,stepInfos.get(2).getFunctionId()); + assertEquals("three",stepInfos.get(2).getOutput()); + assertEquals("step4",stepInfos.get(3).getFunctionName()); + assertEquals(3,stepInfos.get(3).getFunctionId()); + assertEquals("four",stepInfos.get(3).getOutput()); + assertEquals("step5",stepInfos.get(4).getFunctionName()); + assertEquals(4,stepInfos.get(4).getFunctionId()); + assertEquals("five",stepInfos.get(4).getOutput()); + assertNull(stepInfos.get(4).getError()); + + } + + @Test + public void SameInterfaceWorkflowWithSteps() throws Exception { + + ServiceWFAndStep service = dbos.Workflow() + .interfaceClass(ServiceWFAndStep.class) + .implementation(new ServiceWFAndStepImpl()) + .async() + .build(); + + service.setSelf(service); + + String workflowId = "wf-same-1234"; + + try (SetWorkflowID id = new SetWorkflowID(workflowId)) { + service.aWorkflow("hello"); + } + + WorkflowHandle handle = dbosExecutor.retrieveWorkflow(workflowId); + assertEquals("helloonetwo", handle.getResult()); + + List stepInfos = systemDatabase.listWorkflowSteps(workflowId); + assertEquals(2, stepInfos.size()); + + assertEquals("step1",stepInfos.get(0).getFunctionName()); + assertEquals(0,stepInfos.get(0).getFunctionId()); + assertEquals("one",stepInfos.get(0).getOutput()); + assertEquals("step2",stepInfos.get(1).getFunctionName()); + assertEquals(1,stepInfos.get(1).getFunctionId()); + assertEquals("two",stepInfos.get(1).getOutput()); + assertNull(stepInfos.get(1).getError()); + + } + +}