Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 12 additions & 48 deletions core/services/workflow_service/controllers/project_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def create_project(db: Session, name: str, current_user_uuid: UUID) -> UUID:


def create_project_from_template(
db: Session,
name: str,
template_identifier: str,
current_user_uuid: UUID
Expand All @@ -41,8 +42,6 @@ def create_project_from_template(
This method will handle the creation of project, blocks and edges as
defined in the template.yaml
"""
db: Session = next(get_database())

template: WorkflowTemplate =\
template_controller.get_workflow_template_by_identifier(
template_identifier
Expand Down Expand Up @@ -76,31 +75,15 @@ def create_project_from_template(
raise e


def read_project(project_uuid: UUID) -> Project:
logging.debug(f"Reading project with UUID: {project_uuid}")
db: Session = next(get_database())

project = db.query(Project).filter_by(uuid=project_uuid).one_or_none()

if not project:
logging.error(f"Project {project_uuid} not found")
raise HTTPException(status_code=404, detail="Project not found")

return project


def rename_project(project_uuid: UUID, new_name: str, db: Session) -> Project:
logging.debug(f"Renaming project {project_uuid} to {new_name}.")

project = db.query(Project).filter_by(uuid=project_uuid).one_or_none()

if not project:
logging.error(f"Project {project_uuid} not found.")
raise HTTPException(status_code=404, detail="Project not found")
def rename_project(db: Session, project: Project, new_name: str) -> Project:
logging.debug(f"Renaming project {project.uuid} to {new_name}.")

project.name = new_name

logging.info(f"Project {project_uuid} renamed successfully to {new_name}")
db.commit()
db.refresh(project)

logging.debug(f"Project {project.uuid} renamed successfully to {new_name}")
return project


Expand Down Expand Up @@ -154,47 +137,28 @@ def delete_user(project_uuid: UUID, user_uuid: UUID) -> None:
logging.info(f"User {user_uuid} removed from project {project_uuid}")


def delete_project(project_uuid: UUID) -> None:
logging.debug(f"Deleting project with UUID: {project_uuid}")
db: Session = next(get_database())

project = db.query(Project).filter_by(uuid=project_uuid).one_or_none()

if not project:
logging.error(f"Project {project_uuid} not found")
raise HTTPException(status_code=404, detail="Project not found")
def delete_project(db: Session, project: Project) -> None:
logging.debug(f"Deleting project with UUID: {project.uuid}")

db.delete(project)
db.commit()

logging.info(f"Project {project_uuid} deleted successfully")

logging.debug("Project deleted successfully")

def read_all_projects() -> list[Project]:
db: Session = next(get_database())

def read_all_projects(db: Session) -> list[Project]:
projects = db.query(Project).all()

return projects


def read_projects_by_user_uuid(user_uuid: UUID) -> list[Project]:
def read_projects_by_user_uuid(db: Session, user_uuid: UUID) -> list[Project]:
logging.debug(f"Fetching projects for user UUID: {user_uuid}")
db: Session = next(get_database())

projects = (
db.query(Project)
.filter(Project.users.contains([user_uuid]))
.all()
)

if not projects:
logging.error(f"No projects found for user {user_uuid}")
raise HTTPException(
status_code=404,
detail="No projects found for user",
)

logging.info(f"Retrieved {len(projects)} projects for user {user_uuid}")

return projects
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
compute_block_controller,
template_controller,
)
from services.workflow_service.controllers.project_controller import (
read_project,
)
from services.workflow_service.models.block import (
Block,
block_dependencies,
Expand All @@ -45,14 +42,14 @@
BlockStatus,
ConfigType,
)
from services.workflow_service.models import Project
from services.workflow_service.schemas.workflow import (
WorfklowValidationError,
WorkflowEnvsWithBlockInfo,
WorkflowTemplate,
)
from utils.config.environment import ENV
from utils.data.file_handling import bulk_presigned_urls_from_ios
from utils.database.session_injector import get_database

if TYPE_CHECKING:
from uuid import UUID
Expand Down Expand Up @@ -168,7 +165,7 @@ def _get_unconfigured_ios(ios: list[InputOutput]) -> list[InputOutput]:
return result


def get_workflow_configurations(project_id: UUID) -> tuple[
def get_workflow_configurations(db: Session, project_id: UUID) -> tuple[
list[WorkflowEnvsWithBlockInfo],
list[InputOutput], # Workflow Inputs
list[InputOutput], # Intermediates
Expand Down Expand Up @@ -226,8 +223,6 @@ def get_workflow_configurations(project_id: UUID) -> tuple[
- List of InputOutput for workflow outputs
- Dictionary mapping entrypoint UUIDs to Block instances
"""
db: Session = next(get_database())

# 1. Load blocks
blocks = compute_block_controller.get_compute_blocks_by_project(project_id)
block_by_entry_id = {b.selected_entrypoint_uuid: b for b in blocks}
Expand Down Expand Up @@ -478,10 +473,9 @@ def wait_for_dag_registration(
return False


def translate_project_to_dag(project_uuid: UUID) -> str:
def translate_project_to_dag(project: Project, project_uuid: UUID) -> str:
"""Parses a project and its blocks into a DAG, validates it, and saves
it."""
project = read_project(project_uuid)
graph = create_graph(project)
templates = init_templates()
dag_id = _project_id_to_dag_id(project_uuid)
Expand Down
8 changes: 8 additions & 0 deletions core/services/workflow_service/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .block import Block
from .entrypoint import Entrypoint
from .input_output import InputOutput, InputOutputType
from .project import Project

__all__ = [
"Block", "Entrypoint", "InputOutput", "Project"
]
3 changes: 2 additions & 1 deletion core/services/workflow_service/schemas/compute_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@ def replace_minio_host(url: str | None) -> str | None:

class BaseIODTO(BaseModel):
id: UUID | None = None
name: str | None = None
data_type: DataType

@classmethod
def from_input_output(cls, io):
return cls(
id=io.uuid,
name=io.name,
data_type=io.data_type
)

Expand Down Expand Up @@ -353,7 +355,6 @@ def from_sdk_compute_block(cls, cb):


class CreateComputeBlockRequest(BaseModel):
project_id: UUID
cbc_url: str
name: str
custom_name: str
Expand Down
1 change: 0 additions & 1 deletion core/services/workflow_service/schemas/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class ReadAllResponse(BaseModel):


class RenameProjectRequest(BaseModel):
project_uuid: UUID
new_name: str


Expand Down
56 changes: 27 additions & 29 deletions core/services/workflow_service/views/compute_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@
update_block,
update_ios_with_uploads
)
from services.workflow_service.models import (
Project,
Entrypoint,
InputOutput
)
from utils.security.token import User, get_user
from utils.security.resources import (
get_project,
get_entrypoint,
get_ios_by_entrypoint_uuid
)

router = APIRouter(prefix="/compute_block", tags=["compute_block"])

Expand Down Expand Up @@ -94,19 +104,15 @@ async def create(


@router.get(
"/by_project/{project_id}",
"/by_project/{project_uuid}",
response_model=GetNodesByProjectResponse,
)
async def get_by_project(
project_id: UUID | None = None,
_: User = Depends(get_user),
project: Project = Depends(get_project),
):
if not project_id:
raise HTTPException(status_code=422, detail="Project ID is required.")

try:
compute_blocks = get_compute_blocks_by_project(project_id)
status = workflow_controller.dag_status(project_id)
compute_blocks = get_compute_blocks_by_project(project.uuid)
status = workflow_controller.dag_status(project.uuid)

block_uuids = [block.uuid for block in compute_blocks]
dependencies = get_block_dependencies_for_blocks(block_uuids)
Expand All @@ -126,21 +132,15 @@ async def get_by_project(
raise handle_error(e)


@router.get("/entrypoint/{entry_id}/envs/", response_model=ConfigType)
@router.get("/entrypoint/{entrypoint_uuid}/envs/", response_model=ConfigType)
async def get_envs(
entry_id: UUID | None = None,
_: User = Depends(get_user),
entrypoint: Entrypoint = Depends(get_entrypoint)
):
if not entry_id:
raise HTTPException(
status_code=422,
detail="Entrypoint ID is required.",
)

try:
return get_envs_for_entrypoint(entry_id)
return get_envs_for_entrypoint(entrypoint.uuid)
except Exception as e:
logging.exception(f"Error getting envs of entrypoint {entry_id}: {e}")
logging.exception(f"Error getting envs of entrypoint {
entrypoint.uuid}: {e}")
raise handle_error(e)


Expand Down Expand Up @@ -169,19 +169,16 @@ async def update_compute_block(
raise handle_error(e)


@router.get("/entrypoint/{entry_id}/io/", response_model=list[InputOutputDTO])
@router.get(
"/entrypoint/{entrypoint_uuid}/io/",
response_model=list[InputOutputDTO]
)
async def get_io(
entry_id: UUID,
io_type: InputOutputType,
_: User = Depends(get_user),
entrypoint_uuid: UUID,
ios: list[InputOutput] = Depends(get_ios_by_entrypoint_uuid)
):
if not entry_id:
raise HTTPException(
status_code=422,
detail="Entrypoint ID is required.",
)
try:
ios = get_io_for_entrypoint(entry_id, io_type)
presigned_urls = bulk_presigned_urls_from_ios(ios)
return [InputOutputDTO.from_input_output(
io.name,
Expand All @@ -190,7 +187,8 @@ async def get_io(
]
except Exception as e:
logging.exception(
f"Error getting {io_type.value}s of entrypoint {entry_id}: {e}",
f"Error getting {io_type.value}s of entrypoint {
entrypoint_uuid}: {e}",
)
raise handle_error(e)

Expand Down
Loading
Loading