158 changes: 158 additions & 0 deletions pkg/scheduler/plugins/proportion/capacity_policy/capacity_policy.go
@@ -1,29 +1,49 @@
// Copyright 2025 NVIDIA CORPORATION
// SPDX-License-Identifier: Apache-2.0

// Package capacity_policy implements queue capacity and quota checking functionality
// for the KAI scheduler. It ensures that jobs do not exceed their queue's resource
// quotas, both at the direct queue level and parent queue levels.
package capacity_policy

import (
"fmt"

"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/common_info"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_status"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/resource_info"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/constants"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/framework"
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/log"
rs "github.com/NVIDIA/KAI-scheduler/pkg/scheduler/plugins/proportion/resource_share"
)

// capacityCheckFn is a function type that checks if a job's requested resources
// exceed capacity limits. It returns a SchedulableResult indicating whether the
// job can be scheduled.
Collaborator:
It's a style comment, and this is a new project and all, but we don't have many obvious comments here; this one really doesn't add anything that isn't already written on the next line.

type capacityCheckFn func(requestedShare rs.ResourceQuantities, job *podgroup_info.PodGroupInfo) *api.SchedulableResult

// CapacityPolicy implements queue capacity checking and quota enforcement.
// It tracks queue hierarchies and ensures jobs don't exceed resource quotas
// at any level in the hierarchy.
type CapacityPolicy struct {
    queues                 map[common_info.QueueID]*rs.QueueAttributes
    isInferencePreemptible bool
}

// New creates a new CapacityPolicy instance with the given queue attributes
// and inference preemption configuration.
func New(queues map[common_info.QueueID]*rs.QueueAttributes, isInferencePreemptible bool) *CapacityPolicy {
    return &CapacityPolicy{queues, isInferencePreemptible}
}
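For orientation, a hypothetical construction and query of the policy using the signatures above (the variable names are illustrative, not taken from this PR):

```go
// Build the policy from the session's queue attributes, then ask whether a
// job's pending tasks would push its queue over capacity.
cp := New(queueAttributes, false /* isInferencePreemptible */)
result := cp.IsJobOverQueueCapacity(job, tasksToAllocate)
```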

// IsJobOverQueueCapacity checks if a job would exceed its queue's capacity
// when considering all tasks that need to be allocated. This includes both
// regular capacity limits and non-preemptible quota checks.
func (cp *CapacityPolicy) IsJobOverQueueCapacity(job *podgroup_info.PodGroupInfo,
    tasksToAllocate []*pod_info.PodInfo) *api.SchedulableResult {
    requiredQuota := getRequiredQuota(tasksToAllocate)
@@ -36,6 +56,9 @@ func (cp *CapacityPolicy) IsJobOverQueueCapacity(job *podgroup_info.PodGroupInfo
    return cp.isJobOverCapacity(requestedShareQuantities, job, checkFns)
}

// IsNonPreemptibleJobOverQuota specifically checks if a non-preemptible job
// would exceed its queue's quota. This is a stricter check than regular
// capacity checking as non-preemptible jobs have dedicated resource quotas.
func (cp *CapacityPolicy) IsNonPreemptibleJobOverQuota(job *podgroup_info.PodGroupInfo,
    tasksToAllocate []*pod_info.PodInfo) *api.SchedulableResult {

@@ -49,6 +72,9 @@ func (cp *CapacityPolicy) IsNonPreemptibleJobOverQuota(job *podgroup_info.PodGro
    return cp.isJobOverCapacity(requestedShareQuantities, job, checkFns)
}

// IsTaskAllocationOnNodeOverCapacity checks if allocating a specific task
// to a node would exceed capacity limits. This considers both the node's
// resources and the queue's capacity constraints.
func (cp *CapacityPolicy) IsTaskAllocationOnNodeOverCapacity(task *pod_info.PodInfo, job *podgroup_info.PodGroupInfo,
    node *node_info.NodeInfo) *api.SchedulableResult {
    requiredInitQuota := node.GetRequiredInitQuota(task)
@@ -61,6 +87,8 @@ func (cp *CapacityPolicy) IsTaskAllocationOnNodeOverCapacity(task *pod_info.PodI
    return cp.isJobOverCapacity(requestedShare, job, checkFns)
}

// isJobOverCapacity is an internal helper that runs a series of capacity
// check functions to determine if a job exceeds any resource limits.
func (cp *CapacityPolicy) isJobOverCapacity(requestedShare rs.ResourceQuantities, job *podgroup_info.PodGroupInfo,
    checkFns []capacityCheckFn) *api.SchedulableResult {
    for _, checkFn := range checkFns {
@@ -74,6 +102,8 @@ func (cp *CapacityPolicy) isJobOverCapacity(requestedShare rs.ResourceQuantities
    return Schedulable()
}

// getRequiredQuota calculates the total resource requirements for a set of tasks.
// This includes CPU, Memory, and GPU resources.
func getRequiredQuota(tasksToAllocate []*pod_info.PodInfo) *podgroup_info.JobRequirement {
    quota := podgroup_info.JobRequirement{}
    for _, pod := range tasksToAllocate {
@@ -83,3 +113,131 @@ func getRequiredQuota(tasksToAllocate []*pod_info.PodInfo) *podgroup_info.JobReq
    }
    return &quota
}

// getFirstPendingPod returns the first pod in a job that is in Pending status.
// This is used to avoid duplicate quota checks for the same job.
func getFirstPendingPod(job *podgroup_info.PodGroupInfo) *pod_info.PodInfo {
    for _, pod := range job.PodInfos {
        if pod.Status == pod_status.Pending {
            return pod
        }
    }
    return nil
}

// OnSessionOpen is called when a new scheduling session begins. It registers
// the early quota checking function that prevents jobs from being considered
// for scheduling if they would exceed their parent queues' quotas.
func (cp *CapacityPolicy) OnSessionOpen(ssn *framework.Session) {
Collaborator:
Maybe I'm missing something, but I don't think this is called by the changes in the current PR, because this is not a plugin that gets registered.
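If the policy is meant to receive session callbacks, it would have to be wired in through something the framework does register. A minimal sketch under that assumption (proportionPlugin and its fields are hypothetical stand-ins, not code from this PR):

```go
// Hypothetical wiring: an already-registered plugin owns the policy and
// forwards the session-open callback so OnSessionOpen actually runs.
func (pp *proportionPlugin) OnSessionOpen(ssn *framework.Session) {
    cp := New(pp.queues, pp.isInferencePreemptible)
    cp.OnSessionOpen(ssn) // registers the PrePredicateFn below
}
```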

    // Register early quota checks
    ssn.AddPrePredicateFn(func(task *pod_info.PodInfo, job *podgroup_info.PodGroupInfo) error {
        // Only check for the first pending pod to avoid duplicate checks
        firstPending := getFirstPendingPod(job)
Collaborator:
Isn't the fact that the job needs scheduling (and that predicates are running on it) enough to tell that it has pending pods?
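If that holds, the nil guard collapses into the dedup check alone; a sketch of the simplification the comment suggests (assuming the PrePredicateFn only ever runs for pending tasks):

```go
// Run the parent-queue quota check once per job: only for the task that is
// the job's first pending pod. A non-nil task can never equal a nil result,
// so the explicit firstPending == nil check becomes redundant.
if task != getFirstPendingPod(job) {
    return nil
}
return cp.checkParentQueueQuotas(job, ssn)
```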

        if firstPending == nil || task != firstPending {
            return nil
        }

        // Check parent queue quotas
        return cp.checkParentQueueQuotas(job, ssn)
    })
}

// checkParentQueueQuotas verifies that a job's resource requirements don't
// exceed quotas at any level in its queue hierarchy. This includes:
// - GPU quota checks
// - CPU quota checks
// - Memory quota checks
//
// The function traverses up the queue hierarchy starting from the job's
// immediate parent queue. If any quota would be exceeded, it returns an
// error with a detailed message.
//
// Note: Preemptible jobs (PriorityTrainNumber) are allowed to exceed parent
// queue quotas, while non-preemptible jobs must strictly adhere to quotas.
func (cp *CapacityPolicy) checkParentQueueQuotas(job *podgroup_info.PodGroupInfo, ssn *framework.Session) error {
    // Skip quota checks for preemptible jobs
    if job.Priority == constants.PriorityTrainNumber {
Collaborator:
We aim to separate job priority and preemptibility soon, and the Train priority is not the only one that is currently preemptible. We have several scalars that define a queue's resources: quota, limit, and over-quota weight. We should look at the limit here, not the quota; then the check will be correct for all jobs.

        log.InfraLogger.V(5).Infof("Job: <%v/%v> is preemptible, skipping parent queue quota checks", job.Namespace, job.Name)
        return nil
    }

    // Get queue info for this job
    queue, found := ssn.Queues[job.Queue]
    if !found {
        return nil // Can't check quota without queue info
    }

    // Only check parent queues, not the job's direct queue
    currentQueueID := queue.ParentQueue

    for currentQueueID != "" {
        parentQueue, found := ssn.Queues[currentQueueID]
        if !found {
            break
        }
Collaborator:
This, together with line 239, can be a single for statement with initialization, condition, and step.
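A sketch of the consolidated loop (same lookup logic, folded into the for header; it assumes the empty queue ID is never a key in ssn.Queues, so the found condition also covers the top of the hierarchy):

```go
// Walk up the hierarchy with initialization, condition, and step in one for statement.
for parentQueue, found := ssn.Queues[queue.ParentQueue]; found; parentQueue, found = ssn.Queues[parentQueue.ParentQueue] {
    // ... per-parent quota checks ...
}
```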


        // Calculate job's total resource requirements
        jobResources := resource_info.EmptyResource()
        for _, pod := range job.PodInfos {
            if pod.Status == pod_status.Pending {
                jobResources.AddResourceRequirements(pod.ResReq)
            }
        }

        // Check GPU quota
Collaborator:
This part seems duplicated three times; there are helpers in the code base to avoid that. You can loop over the resource types instead; look at proportion/proportion.go for an example.
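A sketch of one way to fold the three checks into a loop; the per-resource pairing below is illustrative, modeled loosely on how proportion/proportion.go iterates resource types:

```go
// Hypothetical dedup: pair each resource's quota with the job's request and
// run one generic check instead of three hand-written copies.
checks := []struct {
    name      string
    quota     float64
    requested float64
}{
    {"GPUs", float64(parentQueue.Resources.GPU.Quota), jobResources.GPUs()},
    {"CPU", float64(parentQueue.Resources.CPU.Quota), jobResources.Cpu()},
    {"Memory", float64(parentQueue.Resources.Memory.Quota), jobResources.Memory()},
}
for _, c := range checks {
    if c.quota > 0 && c.requested > c.quota {
        return fmt.Errorf("parent queue '%s' quota has reached the allowable limit of %s: "+
            "limit is %.0f, workload requested %.0f",
            parentQueue.Name, c.name, c.quota, c.requested)
    }
}
```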

        if parentQueue.Resources.GPU.Quota > 0 && jobResources.GPUs() > float64(parentQueue.Resources.GPU.Quota) {
            errorMsg := fmt.Sprintf(
                "parent queue '%s' quota has reached the allowable limit of GPUs. "+
Collaborator:
A queue can go over the quota, but not the limit. I would change the whole check to address the limit.

"Limit is %.0f GPUs, workload requested %.0f GPUs",
parentQueue.Name,
parentQueue.Resources.GPU.Quota,
jobResources.GPUs())

// Record event
if firstPod := getFirstPendingPod(job); firstPod != nil {
log.InfraLogger.Warningf("Queue quota exceeded: %s", errorMsg)
}

return fmt.Errorf(errorMsg)
}

        // Check CPU quota
        if parentQueue.Resources.CPU.Quota > 0 && jobResources.Cpu() > float64(parentQueue.Resources.CPU.Quota) {
            errorMsg := fmt.Sprintf(
                "parent queue '%s' quota has reached the allowable limit of CPU. "+
                "Limit is %.0f CPU, workload requested %.0f CPU",
                parentQueue.Name,
                float64(parentQueue.Resources.CPU.Quota),
                jobResources.Cpu())

            // Log a warning for visibility (no event is recorded here)
            if firstPod := getFirstPendingPod(job); firstPod != nil {
                log.InfraLogger.Warningf("Queue quota exceeded: %s", errorMsg)
            }

            return fmt.Errorf("%s", errorMsg)
        }

        // Check Memory quota
        if parentQueue.Resources.Memory.Quota > 0 && jobResources.Memory() > float64(parentQueue.Resources.Memory.Quota) {
            errorMsg := fmt.Sprintf(
                "parent queue '%s' quota has reached the allowable limit of Memory. "+
                "Limit is %.0f Memory, workload requested %.0f Memory",
                parentQueue.Name,
                float64(parentQueue.Resources.Memory.Quota),
                jobResources.Memory())

            // Log a warning for visibility (no event is recorded here)
            if firstPod := getFirstPendingPod(job); firstPod != nil {
                log.InfraLogger.Warningf("Queue quota exceeded: %s", errorMsg)
            }

            return fmt.Errorf("%s", errorMsg)
        }

        // Move up the hierarchy
        currentQueueID = parentQueue.ParentQueue
    }

    return nil
}