
Commit ba2021a

[https://nvbugs/5527655][feat] Add NUMA-aware CPU affinity autoconfiguration

Signed-off-by: Dan Hansen <[email protected]>

1 parent ec31363 commit ba2021a

4 files changed: +68 -26 lines changed

tensorrt_llm/executor/base_worker.py

Lines changed: 34 additions & 1 deletion
@@ -2,11 +2,13 @@
 import datetime
 import enum
 import json
+import os
 import weakref
 from pathlib import Path
 from queue import Queue
 from typing import Dict, List, Optional, Tuple, Union

+import psutil
 import torch

 from tensorrt_llm.logger import logger
@@ -19,7 +21,7 @@
 from ..llmapi.llm_args import BaseLlmArgs, PybindMirror
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import global_tracer
-from ..llmapi.utils import _SyncQueue, logger_debug
+from ..llmapi.utils import _SyncQueue, get_numa_aware_cpu_affinity, logger_debug
 from ..lora_manager import LoraManager
 from ..metrics import RequestEventTiming
 from ..prompt_adapter_manager import PromptAdapterManager
@@ -92,13 +94,44 @@ def __init__(
         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)

+    def _configure_affinity(self, device_id):
+        '''
+        Probe and configure the CPU affinity of the worker.
+        '''
+
+        # Get the current affinity setting
+        pid = os.getpid()
+        process = psutil.Process(pid)
+        cpu_affinity = process.cpu_affinity()
+
+        all_cpus = list(range(psutil.cpu_count()))
+
+        constrained_affinity = (cpu_affinity != all_cpus)
+
+        # If the process is affined to a constrained set of CPUs, warn the
+        # user so they can confirm that this is intended
+        if constrained_affinity:
+            logger.warning(
+                f"Worker process {pid} is affined to run on the following CPUs: "
+                f"{cpu_affinity} (subset of all logical CPUs). This may harm "
+                "performance if set incorrectly.")
+
+        # If affinity is unconstrained or the user has explicitly requested it,
+        # choose the optimal affinity based upon the NUMA topology
+        if not constrained_affinity or os.environ.get(
+                "TLLM_NUMA_AWARE_WORKER_AFFINITY", "0") == "1":
+            process.cpu_affinity(get_numa_aware_cpu_affinity(device_id))
+
     def _get_comm_ranks_device_id(self):
         device_id = self.global_rank % torch.cuda.device_count()
         torch.cuda.set_device(device_id)
         # Make sure C++ executor would use same devices/ranks as py_executor
         global_rank = global_mpi_rank()
         comm_ranks = mpi_comm().allgather(global_rank)
         device_ids = mpi_comm().allgather(device_id)
+
+        self._configure_affinity(device_id)
+
         return comm_ranks, device_ids

     def setup_engine(self):
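
The new _configure_affinity respects any deliberately constrained CPU mask and only re-pins the worker when the mask is unconstrained, or when the user forces the override with TLLM_NUMA_AWARE_WORKER_AFFINITY=1. A minimal standalone sketch of the same probe-and-override flow, assuming device id 0 (get_numa_aware_cpu_affinity is the helper added in tensorrt_llm/llmapi/utils.py below):

    import os

    import psutil

    from tensorrt_llm.llmapi.utils import get_numa_aware_cpu_affinity

    process = psutil.Process(os.getpid())
    constrained = process.cpu_affinity() != list(range(psutil.cpu_count()))

    # Mirror the worker's logic: honor a deliberate pinning unless overridden
    if not constrained or os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY",
                                         "0") == "1":
        process.cpu_affinity(get_numa_aware_cpu_affinity(0))

    print(process.cpu_affinity())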

tensorrt_llm/executor/ray_gpu_worker.py

Lines changed: 3 additions & 0 deletions
@@ -194,6 +194,9 @@ def _get_comm_ranks_device_id(self):

         torch.distributed.all_gather_object(comm_ranks, global_rank)
         torch.distributed.all_gather_object(device_ids, self.device_id)
+
+        self._configure_affinity(self.device_id)
+
         return comm_ranks, device_ids

     def enqueue_request(self,

tensorrt_llm/executor/worker.py

Lines changed: 1 addition & 11 deletions
@@ -18,8 +18,7 @@
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, set_global_tracer
-from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
-                            clear_sched_affinity, logger_debug,
+from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, logger_debug,
                             print_traceback_on_error)
 from ..sampling_params import BatchedLogitsProcessor
 from .base_worker import BaseWorker
@@ -245,15 +244,6 @@ def worker_main(
     mpi_comm().barrier()
     logger_debug(f"Worker {mpi_rank()} entering worker_main...\n", "green")

-    pid = os.getpid()
-    cpus = os.sched_getaffinity(pid)
-    if cpus:
-        logger.warning(
-            f"Found worker process {pid} was bound to {cpus}, this may harm "
-            "performance.", )
-        logger.warning(f"Will clear the cpu affinity")
-        clear_sched_affinity(pid)
-
     result_queue: Optional[IpcQueue] = None
     result_queues: Optional[List[IpcQueue]] = None
tensorrt_llm/llmapi/utils.py

Lines changed: 30 additions & 14 deletions
@@ -1,9 +1,11 @@
 import asyncio
 import collections
+import ctypes
 import datetime
 import hashlib
 import inspect
 import io
+import math
 import os
 import re
 import sys
@@ -513,24 +515,38 @@ def get(self, timeout=None):
         time.sleep(0.01)


-def set_sched_setaffinity(required_cores: int):
-    ''' Set the CPU affinity of the current process to the required number of
-    cores.
-
-    Known issue: This may race with other processes that also set the affinity.
+def get_numa_aware_cpu_affinity(device_id):
+    '''
+    Given the CUDA device_id, query NVML and return the ideal CPU affinity (a
+    list of CPU ids) based upon the NUMA topology.
     '''
-    cpu_percentages = psutil.cpu_percent(percpu=True)
-    # sort the cores by usage
-    free_cores = sorted(range(len(cpu_percentages)),
-                        key=lambda i: cpu_percentages[i])
+    cpu_count = psutil.cpu_count()
+
+    # Get the number of bits per ulong
+    c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8
+
+    # Determine how large our CPU set array from NVML needs to be
+    cpu_set_size = math.ceil(cpu_count / c_ulong_bits)
+
+    # Initialize NVML
+    import pynvml
+    pynvml.nvmlInit()
+
+    # Get the ideal CPU affinity for this device
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    affinity_masks = pynvml.nvmlDeviceGetCpuAffinity(handle, cpu_set_size)

-    pid = os.getpid()
-    os.sched_setaffinity(pid, set(free_cores[:required_cores]))
+    # Convert the CPU masks to a Python list of CPU ids
+    cpu_affinity = []
+    for cpu_id in range(cpu_count):
+        mask_array_index = cpu_id // c_ulong_bits
+        mask_bit_index = cpu_id % c_ulong_bits
+        if affinity_masks[mask_array_index] & (1 << mask_bit_index):
+            cpu_affinity.append(cpu_id)

+    pynvml.nvmlShutdown()

-def clear_sched_affinity(pid: int):
-    ''' Clear the CPU affinity of the current process. '''
-    os.sched_setaffinity(pid, set(range(psutil.cpu_count())))
+    return cpu_affinity


 def generate_api_docs_as_docstring(model: Type[BaseModel],
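
nvmlDeviceGetCpuAffinity returns an array of C ulongs with one bit per logical CPU, which the loop above unpacks into a list of CPU ids. A small worked example of the same unpacking with fabricated masks (the values below are invented for illustration and assume a 64-bit ulong, describing a hypothetical 128-CPU machine where CPUs 0-3 and 64-67 are local to the GPU):

    import ctypes

    # Bits per C ulong (64 on typical Linux x86-64 builds)
    c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8

    # Two hypothetical masks: the low four bits set in each ulong
    affinity_masks = [0b1111, 0b1111]
    cpu_count = len(affinity_masks) * c_ulong_bits

    cpu_affinity = [
        cpu_id for cpu_id in range(cpu_count)
        if affinity_masks[cpu_id // c_ulong_bits] & (1 <<
                                                     (cpu_id % c_ulong_bits))
    ]
    print(cpu_affinity)  # [0, 1, 2, 3, 64, 65, 66, 67] when c_ulong_bits == 64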
