2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,8 @@

## [Unreleased]

- [Feature] Added distributed training support using CustomTrainingJobOp for scalable model training with worker replicas. Nodes can be opted in via node names or tags, using a primary + worker pool architecture with configurable machine types and accelerators; a minimal configuration sketch follows this file's diff.

## [0.12.0] - 2025-03-11

- Support for python 3.11 & 3.12 added, dropped support for python 3.8
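A minimal vertexai.yml fragment enabling the feature above might look as follows; the tag name and machine settings are illustrative, and the full commented template ships in kedro_vertexai/config.py (see the config.py diff below):

run_config:
  # ...existing run_config entries (image, root, etc.)...
  distributed_training:
    enabled_for_tags:
      - "distributed"
    worker_pool:
      machine_type: "n1-standard-4"
      replica_count: 2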
76 changes: 74 additions & 2 deletions kedro_vertexai/config.py
@@ -123,6 +123,38 @@
# allow_queueing: false
# max_run_count: none
# max_concurrent_run_count: 1

# Optional distributed training configuration
# distributed_training:
# # Enable distributed training for specific node names
# enabled_for_node_names:
# - "training_node"
# - "model_training"
#
# # Enable distributed training for nodes with specific tags
# enabled_for_tags:
# - "distributed"
# - "gpu-intensive"
#
# # Primary replica configuration (must have replica_count = 1)
# primary_pool:
# machine_type: "n1-standard-4"
# replica_count: 1
# accelerator_type: "NVIDIA_TESLA_T4"
# accelerator_count: 1
#
# # Worker pool configuration (can have replica_count > 1)
# worker_pool:
# machine_type: "n1-standard-4"
# replica_count: 2
# accelerator_type: "NVIDIA_TESLA_T4"
# accelerator_count: 1
#
# # Base output directory for distributed training jobs
# base_output_directory: "gs://your-bucket/distributed-training-output/"
#
# # Service account for distributed training (optional, defaults to global service_account)
# service_account: "[email protected]"
"""


@@ -170,6 +202,8 @@ class GroupingConfig(BaseModel):
def class_valid(cls, v, values, **kwargs):
try:
grouper_class = dynamic_load_class(v)
if grouper_class is None:
raise ValueError(f"Could not load grouping class {v}")
class_sig = signature(grouper_class)
if "params" in values.data:
class_sig.bind(None, **values.data["params"])
@@ -225,6 +259,22 @@ class ScheduleConfig(BaseModel):
max_concurrent_run_count: Optional[int] = 1


class WorkerPoolConfig(BaseModel):
machine_type: str = "n1-standard-4"
replica_count: int = 1
accelerator_type: Optional[str] = None
accelerator_count: Optional[int] = None


class DistributedTrainingConfig(BaseModel):
enabled_for_node_names: Optional[List[str]] = []
enabled_for_tags: Optional[List[str]] = []
primary_pool: Optional[WorkerPoolConfig] = WorkerPoolConfig()
worker_pool: Optional[WorkerPoolConfig] = WorkerPoolConfig(replica_count=2)
base_output_directory: Optional[str] = None
service_account: Optional[str] = None


class RunConfig(BaseModel):
image: str
root: Optional[str] = None
@@ -243,13 +293,35 @@ class RunConfig(BaseModel):
dynamic_config_providers: Optional[List[DynamicConfigProviderConfig]] = []
mlflow: Optional[MLFlowVertexAIConfig] = None
schedules: Optional[Dict[str, ScheduleConfig]] = None
distributed_training: Optional[DistributedTrainingConfig] = None

def resources_for(self, node: str, tags: Optional[set] = None):
    if self.resources is None:
        return {}
    default_config = self.resources["__default__"].dict()
    return self._config_for(node, tags or set(), self.resources, default_config)

def node_selectors_for(self, node: str, tags: Optional[set] = None):
    if self.node_selectors is None:
        return {}
    return self._config_for(node, tags or set(), self.node_selectors)

def should_use_distributed_training(self, node: str, tags: Optional[set] = None) -> bool:
"""Check if a node should use distributed training based on configuration."""
if not self.distributed_training:
return False

tags = tags or set()

# Check node names
if node in (self.distributed_training.enabled_for_node_names or []):
return True

# Check tags
if any(tag in (self.distributed_training.enabled_for_tags or []) for tag in tags):
return True

return False

@staticmethod
def _config_for(
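A quick sanity-check sketch of the selection logic added to RunConfig above. It assumes experiment_name is the only other required RunConfig field; the image URI, experiment, node, and tag names are placeholders:

from kedro_vertexai.config import DistributedTrainingConfig, RunConfig

cfg = RunConfig(
    image="gcr.io/example-project/pipeline:latest",  # placeholder image
    experiment_name="example-experiment",  # placeholder experiment
    distributed_training=DistributedTrainingConfig(
        enabled_for_node_names=["training_node"],
        enabled_for_tags=["distributed"],
    ),
)

assert cfg.should_use_distributed_training("training_node")  # matched by name
assert cfg.should_use_distributed_training("other_node", {"distributed"})  # matched by tag
assert not cfg.should_use_distributed_training("other_node", {"cpu"})  # no match
# With no distributed_training section at all, the check is always False
plain = RunConfig(
    image="gcr.io/example-project/pipeline:latest",
    experiment_name="example-experiment",
)
assert not plain.should_use_distributed_training("training_node")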
159 changes: 142 additions & 17 deletions kedro_vertexai/generator.py
@@ -10,6 +10,7 @@
from kfp import dsl
from kfp.dsl import PipelineTask
from makefun import with_signature
from google_cloud_pipeline_components.v1.custom_job import CustomTrainingJobOp

from kedro_vertexai.config import (
KedroVertexAIRunnerConfig,
@@ -201,25 +202,41 @@ def _build_kfp_tasks(
]
).strip()

# Check if this node should use distributed training
if self.run_config.should_use_distributed_training(group_name, tags):
    # Create CustomTrainingJobOp for distributed training
    task = self._create_custom_training_job_task(
        name=name,
        image=image,
        kedro_command=kedro_command,
        nodes_group=nodes_group,
        tags=tags,
        params_signature=params_signature,
        component_params=component_params,
        should_add_params=should_add_params,
    )
else:
    # Create standard container component
    @dsl.container_component
    @with_signature(f"{name.replace('-', '_')}({params_signature})")
    def component(*args, **kwargs):
        dynamic_parameters = ",".join(
            [f"{k}={kwargs[k]}" for k in params.keys()]
        )

        return dsl.ContainerSpec(
            image=image,
            command=["/bin/bash", "-c"],
            args=[
                node_command,
                " --params",  # TODO what if there is no dynamic params?
                f" {dynamic_parameters}",
            ],
        )

    task = component(**component_params)
    self._configure_resources(name, tags, task)

kfp_tasks[name] = task

return kfp_tasks
@@ -249,6 +266,114 @@ def _generate_gcp_env_vars_command(self) -> str:
region = vertex_conf.get("region")
return f"GCP_PROJECT_ID={project_id} GCP_REGION={region}"

def _create_custom_training_job_task(
self,
name: str,
image: str,
kedro_command: str,
nodes_group: List,
tags: set,
params_signature: str,
component_params: Dict,
should_add_params: bool,
) -> PipelineTask:
"""Create a CustomTrainingJobOp task for distributed training."""

if not self.run_config.distributed_training:
raise ValueError("Distributed training config is required for CustomTrainingJobOp")

dt_config = self.run_config.distributed_training

# Ensure primary_pool and worker_pool are not None
if not dt_config.primary_pool:
raise ValueError("Primary pool configuration is required for distributed training")
if not dt_config.worker_pool:
raise ValueError("Worker pool configuration is required for distributed training")

# Build the full command with all necessary setup
full_command = " ".join([
h + " " if (h := self._generate_hosts_file()) else "",
self._generate_params_command(should_add_params),
"MLFLOW_RUN_ID=\"{{$.inputs.parameters['mlflow_run_id']}}\" "
if is_mlflow_enabled()
else "",
self._generate_gcp_env_vars_command(),
kedro_command,
]).strip()

# Build worker pool specs based on configuration
worker_pool_specs = []

# Build primary spec with machine type from config or default
primary_machine_type = dt_config.primary_pool.machine_type
primary_spec = {
"machine_spec": {
"machine_type": primary_machine_type,
},
"replica_count": 1, # Primary must always be 1
"container_spec": {
"image_uri": image,
"command": ["/bin/bash", "-c"],
"args": [full_command],
},
}

# Add accelerator config from distributed training config
if dt_config.primary_pool.accelerator_type:
primary_spec["machine_spec"]["accelerator_type"] = dt_config.primary_pool.accelerator_type
if dt_config.primary_pool.accelerator_count:
primary_spec["machine_spec"]["accelerator_count"] = dt_config.primary_pool.accelerator_count

worker_pool_specs.append(primary_spec)

# Worker pool (can have replica_count > 1)
if dt_config.worker_pool.replica_count > 0:
worker_machine_type = dt_config.worker_pool.machine_type
worker_spec = {
"machine_spec": {
"machine_type": worker_machine_type,
},
"replica_count": dt_config.worker_pool.replica_count,
"container_spec": {
"image_uri": image,
"command": ["/bin/bash", "-c"],
"args": [full_command],
},
}

# Add accelerator config from distributed training config
if dt_config.worker_pool.accelerator_type:
worker_spec["machine_spec"]["accelerator_type"] = dt_config.worker_pool.accelerator_type
if dt_config.worker_pool.accelerator_count:
worker_spec["machine_spec"]["accelerator_count"] = dt_config.worker_pool.accelerator_count

worker_pool_specs.append(worker_spec)

# Create CustomTrainingJobOp task
# Note: While KFP compiler generates names like "comp-custom-training-job",
# "comp-custom-training-job-2", etc., creating unique function instances
# ensures each distributed training node gets processed separately.

# Use service account from distributed training config, fallback to global service account
service_account = dt_config.service_account or self.run_config.service_account or ""

# Create the task directly using CustomTrainingJobOp
task = CustomTrainingJobOp(
display_name=f"distributed-{name}",
worker_pool_specs=worker_pool_specs,
base_output_directory=dt_config.base_output_directory or f"gs://{self.run_config.root}/distributed-training-output/",
service_account=service_account,
**component_params
)

# Set display name to the node name for better identification on UI
task.set_display_name(name)

return task

def _configure_resources(self, name: str, tags: set, task: PipelineTask):
resources = self.run_config.resources_for(name, tags)
node_selectors = self.run_config.node_selectors_for(name, tags)
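For reference, with the default pools (one primary replica, two workers, no accelerators) the method above assembles a worker_pool_specs payload shaped like the one below, following the Vertex AI CustomJob worker pool spec format; the image URI and command string are placeholders:

# Shape of worker_pool_specs built by _create_custom_training_job_task
worker_pool_specs = [
    {   # pool 0: primary, always exactly one replica
        "machine_spec": {"machine_type": "n1-standard-4"},
        "replica_count": 1,
        "container_spec": {
            "image_uri": "gcr.io/example-project/pipeline:latest",
            "command": ["/bin/bash", "-c"],
            "args": ["GCP_PROJECT_ID=... GCP_REGION=... kedro run --nodes ..."],
        },
    },
    {   # pool 1: workers, replica_count taken from worker_pool config
        "machine_spec": {"machine_type": "n1-standard-4"},
        "replica_count": 2,
        "container_spec": {
            "image_uri": "gcr.io/example-project/pipeline:latest",
            "command": ["/bin/bash", "-c"],
            "args": ["GCP_PROJECT_ID=... GCP_REGION=... kedro run --nodes ..."],
        },
    },
]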
31 changes: 26 additions & 5 deletions poetry.lock
Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ pydantic = ">=2,<3"
google-auth = "<3"
google-cloud-scheduler = ">=2.3.2"
google-cloud-iam = "<3"
google-cloud-pipeline-components = ">=2.20.1"
gcsfs = ">=2022.1"
fsspec = ">=2022.1"
google-cloud-storage = "<3.0.0"