Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/main/java/dev/dbos/transact/context/DBOSContext.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
package dev.dbos.transact.context;

public class DBOSContext {
private String workflowId;
private volatile String workflowId;
private String user;
private String functionId;
private volatile int functionId;
private String stepId;

public DBOSContext() {

}
public DBOSContext(String workflowId, int functionId) {
this.workflowId = workflowId;
this.functionId = functionId ;
}

public String getWorkflowId() {
return workflowId;
}
Expand All @@ -22,12 +30,12 @@ public void setUser(String user) {
this.user = user;
}

public String getFunctionId() {
public int getFunctionId() {
return functionId;
}

public void setFunctionId(String functionId) {
this.functionId = functionId;
public int getAndIncrementFunctionId() {
return functionId++;
}

public String getStepId() {
Expand All @@ -37,5 +45,9 @@ public String getStepId() {
public void setStepId(String stepId) {
this.stepId = stepId;
}

public DBOSContext copy() {
return new DBOSContext(workflowId, functionId);
}
}

Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
package dev.dbos.transact.context;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DBOSContextHolder {
private static final ThreadLocal<DBOSContext> contextHolder = ThreadLocal.withInitial(DBOSContext::new);
private static Logger logger = LoggerFactory.getLogger(DBOSContextHolder.class);

public static DBOSContext get() {
return contextHolder.get();
}

public static void clear() {
contextHolder.remove();
logger.debug("context cleared for thread " + Thread.currentThread().getId());
}

public static void set(DBOSContext context) {
contextHolder.set(context);
logger.debug("context set for thread " + Thread.currentThread().getId());
}

}
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/dev/dbos/transact/context/SetWorkflowID.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ public SetWorkflowID(String workflowId) {

@Override
public void close() {
DBOSContext context = DBOSContextHolder.get();
context.setWorkflowId(previousWorkflowId);
DBOSContextHolder.clear();
}
}

212 changes: 212 additions & 0 deletions src/main/java/dev/dbos/transact/database/StepsDAO.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package dev.dbos.transact.database;

import dev.dbos.transact.Constants;
import dev.dbos.transact.exceptions.UnExpectedStepException;
import dev.dbos.transact.exceptions.WorkflowCancelledException;
import dev.dbos.transact.exceptions.DBOSWorkflowConflictException;
import dev.dbos.transact.json.JSONUtil;
import dev.dbos.transact.workflow.StepInfo;
import dev.dbos.transact.workflow.WorkflowState;
import dev.dbos.transact.workflow.internal.StepResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

public class StepsDAO {

private Logger logger = LoggerFactory.getLogger(StepsDAO.class);
private DataSource dataSource ;

StepsDAO(DataSource dataSource) {
this.dataSource = dataSource;
}

public void recordStepResultTxn(StepResult result) throws SQLException
{

String sql = String.format(
"INSERT INTO %s.operation_outputs (workflow_uuid, function_id, function_name, output, error) " +
"VALUES (?, ?, ?, ?, ?)",
Constants.DB_SCHEMA
);

try (Connection connection = dataSource.getConnection() ;
PreparedStatement pstmt = connection.prepareStatement(sql)) {
int paramIdx = 1;
pstmt.setString(paramIdx++, result.getWorkflowId());
pstmt.setInt(paramIdx++, result.getFunctionId());
pstmt.setString(paramIdx++, result.getFunctionName());

if (result.getOutput() != null) {
pstmt.setString(paramIdx++, result.getOutput());
} else {
pstmt.setNull(paramIdx++, Types.LONGVARCHAR);
}

if (result.getError() != null) {
pstmt.setString(paramIdx++, result.getError());
} else {
pstmt.setNull(paramIdx++, Types.LONGVARCHAR);
}

pstmt.executeUpdate();

} catch (SQLException e) {
if ("23505".equals(e.getSQLState())) {
throw new DBOSWorkflowConflictException(result.getWorkflowId(),
String.format("Workflow %s already exists", result.getWorkflowId()));
} else {
throw e;
}
}
}

/**
* Checks the execution status and output of a specific operation within a workflow.
* This method corresponds to Python's '_check_operation_execution_txn'.
*
* @param workflowId The UUID of the workflow.
* @param functionId The ID of the function/operation.
* @param functionName The expected name of the function/operation.
* @param connection The active JDBC connection (corresponding to Python's 'conn: sa.Connection').
* @return A {@link StepResult} object if the operation has completed, otherwise {@code null}.
* @throws IllegalStateException If the workflow does not exist in the status table.
* @throws WorkflowCancelledException If the workflow is in a cancelled status.
* @throws UnExpectedStepException If the recorded function name for the operation does not match the provided name.
* @throws SQLException For other database access errors.
*/
public StepResult checkStepExecutionTxn(
String workflowId,
int functionId,
String functionName,
Connection connection
) throws SQLException, IllegalStateException, WorkflowCancelledException, UnExpectedStepException {

String workflowStatusSql = String.format(
"SELECT status FROM %s.workflow_status WHERE workflow_uuid = ?",
Constants.DB_SCHEMA
);

String workflowStatus = null;
try (PreparedStatement pstmt = connection.prepareStatement(workflowStatusSql)) {
pstmt.setString(1, workflowId);
try (ResultSet rs = pstmt.executeQuery()) {
if (rs.next()) {
workflowStatus = rs.getString("status");
}
}
}

if (workflowStatus == null) {
throw new IllegalStateException(String.format("Error: Workflow %s does not exist", workflowId));
}

if (Objects.equals(workflowStatus, WorkflowState.CANCELLED.name())) {
throw new WorkflowCancelledException(
String.format("Workflow %s is cancelled. Aborting function.", workflowId)
);
}

String operationOutputSql = String.format(
"SELECT output, error, function_name " +
"FROM %s.operation_outputs " +
"WHERE workflow_uuid = ? AND function_id = ?",
Constants.DB_SCHEMA
);

StepResult recordedResult = null;
String recordedFunctionName = null;

try (PreparedStatement pstmt = connection.prepareStatement(operationOutputSql)) {
pstmt.setString(1, workflowId);
pstmt.setInt(2, functionId);
try (ResultSet rs = pstmt.executeQuery()) {
if (rs.next()) { // Check if any operation output row exists
String output = rs.getString("output");
String error = rs.getString("error");
recordedFunctionName = rs.getString("function_name");
recordedResult = new StepResult(workflowId, functionId, recordedFunctionName, output, error);
}
}
}

if (recordedResult == null) {
return null;
}

if (!Objects.equals(functionName, recordedFunctionName)) {
throw new UnExpectedStepException(
workflowId,
functionId,
functionName,
recordedFunctionName
);
}

return recordedResult;
}

public List<StepInfo> listWorkflowSteps(String workflowId) throws SQLException {
String sqlTemplate = "SELECT function_id, function_name, output, error, child_workflow_id " +
"FROM %s.operation_outputs " +
"WHERE workflow_uuid = ? " +
"ORDER BY function_id;";
final String sql = String.format(sqlTemplate, Constants.DB_SCHEMA);
System.out.println(sql);


List<StepInfo> steps = new ArrayList<>();

try (Connection connection = dataSource.getConnection();
PreparedStatement stmt = connection.prepareStatement(sql)) {

stmt.setString(1, workflowId);

try (ResultSet rs = stmt.executeQuery()) {

while (rs.next()) {
int functionId = rs.getInt("function_id");
String functionName = rs.getString("function_name");
String outputData = rs.getString("output");
String errorData = rs.getString("error");
String childWorkflowId = rs.getString("child_workflow_id");
System.out.println(functionId);

// Deserialize output if present
Object output = null;
if (outputData != null) {
try {
output = JSONUtil.deserialize(outputData);
} catch (Exception e) {
throw new RuntimeException("Failed to deserialize output for function " + functionId, e);
}
}

// Deserialize error if present
Exception error = null;
if (errorData != null) {
try {
// TODO error = JSONUtil.deserialize(errorData);
error = new Exception(errorData) ;
} catch (Exception e) {
throw new RuntimeException("Failed to deserialize error for function " + functionId, e);
}
}

steps.add(new StepInfo(functionId, functionName, output, error, childWorkflowId));
}
}
} catch (SQLException e) {
throw new SQLException("Failed to retrieve workflow steps for workflow: " + workflowId, e);
}

return steps;
}
}


Loading
Loading