Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 75 additions & 115 deletions install_nixl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# Location of the NIXL source checkout; the NIXL build commands run from here.
NIXL_DIR = os.path.join('/tmp', 'nixl_source')
# Directory handed to the NIXL build as the UCX location (-Ducx_path).
UCX_INSTALL_DIR = os.path.join('/tmp', 'ucx_install')
# Directory handed to the NIXL build as the Libfabric location (-Dlibfabric_path).
LIBFABRIC_INSTALL_DIR = os.path.join('/tmp', 'libfabric_install')
# Prefix that the NIXL meson build installs into (--prefix).
NIXL_INSTALL_DIR = os.path.join('/tmp', 'nixl_install')

# --- Repository and Version Configuration ---
UCX_REPO_URL = 'https://github.com/openucx/ucx.git'
Expand Down Expand Up @@ -81,6 +82,77 @@ def install_system_dependencies():
run_command(['apt-get', 'install', '-y'] + apt_packages)
print("--- System dependencies installed successfully. ---\n", flush=True)

def _prepend_to_env(name, new_value, sep=":"):
    """Prepend *new_value* to the ``sep``-separated variable *name* in os.environ.

    When the variable is unset or empty, no trailing separator is emitted.
    An empty element in C_INCLUDE_PATH / CPLUS_INCLUDE_PATH / LD_LIBRARY_PATH
    is interpreted as "the current directory", which must not be added by
    accident. A pre-existing value is kept verbatim after the new entries.
    """
    existing = os.environ.get(name, "")
    os.environ[name] = f"{new_value}{sep}{existing}" if existing else new_value


def install_nixl():
    """Build NIXL from the source tree in NIXL_DIR and install it.

    Steps, in order:
      1. Export LIBUCX_ROOT, LIBNIXL_ROOT and LIBFABRIC_ROOT and prepend the
         matching directories to PKG_CONFIG_PATH, C_INCLUDE_PATH,
         CPLUS_INCLUDE_PATH, LDFLAGS and LD_LIBRARY_PATH in ``os.environ``
         (the changes persist for the rest of this process and are inherited
         by the child processes started below).
      2. Install the Python build requirements with pip.
      3. Configure with meson, then build and install with ninja into
         NIXL_INSTALL_DIR and refresh the linker cache with ``ldconfig``.
      4. ``pip install`` the Python package from the source tree.

    Every command runs with an explicit ``cwd`` so the working directory of
    the calling process is never changed.

    Raises:
        subprocess.CalledProcessError: if any command exits with a non-zero
            status.
        Exception: any other failure (for example a missing executable or
            source directory) is reported and re-raised unchanged.
    """
    # Function-scope import keeps this block self-contained.
    import sys

    # Root locations consumed by the NIXL build.
    os.environ["LIBUCX_ROOT"] = UCX_INSTALL_DIR
    os.environ["LIBNIXL_ROOT"] = NIXL_INSTALL_DIR
    os.environ["LIBFABRIC_ROOT"] = LIBFABRIC_INSTALL_DIR

    nixl_include_dir = f"{NIXL_INSTALL_DIR}/include"
    nixl_lib_dir = f"{NIXL_INSTALL_DIR}/lib"

    pkgconfig_dirs = ":".join(
        f"{root}/lib/pkgconfig"
        for root in (LIBFABRIC_INSTALL_DIR, UCX_INSTALL_DIR, NIXL_INSTALL_DIR)
    )
    _prepend_to_env("PKG_CONFIG_PATH", pkgconfig_dirs)
    _prepend_to_env("CPLUS_INCLUDE_PATH", nixl_include_dir)
    _prepend_to_env("C_INCLUDE_PATH", nixl_include_dir)
    _prepend_to_env("LDFLAGS", f"-L{nixl_lib_dir}", sep=" ")
    _prepend_to_env("LD_LIBRARY_PATH", nixl_lib_dir)

    # Use the running interpreter's pip so packages land in the same
    # environment that is executing this script.
    pip = [sys.executable, "-m", "pip"]
    build_dir = os.path.join(NIXL_DIR, "builddir")

    try:
        # Python-side build requirements.
        subprocess.run(
            pip + ["install", "--upgrade", "meson", "pybind11", "patchelf"],
            check=True, cwd=NIXL_DIR,
        )
        subprocess.run(
            pip + ["install", "-r", "requirements.txt"],
            check=True, cwd=NIXL_DIR,
        )

        # Configure the native build.
        subprocess.run(
            [
                "meson", "setup",
                "--wipe",
                f"--prefix={NIXL_INSTALL_DIR}",
                "--buildtype=release",
                "-Ddisable_gds_backend=true",
                f"-Dlibfabric_path={LIBFABRIC_INSTALL_DIR}",
                f"-Ducx_path={UCX_INSTALL_DIR}",
                "builddir", ".",
            ],
            check=True, cwd=NIXL_DIR,
        )

        # Build, install into the prefix, and refresh the linker cache.
        subprocess.run(["ninja"], check=True, cwd=build_dir)
        subprocess.run(["ninja", "install"], check=True, cwd=build_dir)
        subprocess.run(["ldconfig"], check=True, cwd=build_dir)

        # Install the Python package from the source tree.
        subprocess.run(pip + ["install", "."], check=True, cwd=NIXL_DIR)
    except subprocess.CalledProcessError as e:
        print(f"Command failed with exit code {e.returncode}: {e.cmd}", flush=True)
        raise
    except Exception as e:
        print(f"Unexpected error: {e}", flush=True)
        raise

def build_and_install_prerequisites(args):
"""Builds UCX and NIXL from source, creating a self-contained wheel."""
Expand All @@ -100,8 +172,6 @@ def build_and_install_prerequisites(args):
return

print("\n--> No installed package or cached wheel found. Starting full build process...", flush=True)
print("\n--> Installing auditwheel...", flush=True)
run_command([sys.executable, '-m', 'pip', 'install', 'auditwheel'])
install_system_dependencies()
ucx_install_path = os.path.abspath(UCX_INSTALL_DIR)
print(f"--> Using wheel cache directory: {WHEELS_CACHE_HOME}", flush=True)
Expand Down Expand Up @@ -132,7 +202,7 @@ def build_and_install_prerequisites(args):
print("--- UCX build and install complete ---", flush=True)

# -- Step 2: Build Libfabric from source --
print(f"\n[2/4] Configuring and building Libfabric (ref: {LIBFABRIC_REF}) from source...", flush=True)
print(f"\n[2/3] Configuring and building Libfabric (ref: {LIBFABRIC_REF}) from source...", flush=True)
if not os.path.exists(LIBFABRIC_DIR):
run_command(['git', 'clone', LIBFABRIC_REPO_URL, LIBFABRIC_DIR])
run_command(['git', 'checkout', LIBFABRIC_REF], cwd=LIBFABRIC_DIR)
Expand All @@ -150,120 +220,10 @@ def build_and_install_prerequisites(args):


# -- Step 3: Build NIXL wheel from source --
print(f"\n[3/4] Building NIXL (branch: {NIXL_BRANCH}) wheel from source...", flush=True)
print(f"\n[3/3] Building NIXL (branch: {NIXL_BRANCH}) wheel from source...", flush=True)
if not os.path.exists(NIXL_DIR):
run_command(['git', 'clone', '--branch', NIXL_BRANCH, NIXL_REPO_URL, NIXL_DIR])

build_env = os.environ.copy()
# Configure environment to find both UCX and Libfabric
ucx_install_path = os.path.abspath(UCX_INSTALL_DIR)
lf_install_path = os.path.abspath(LIBFABRIC_INSTALL_DIR)

ucx_pkg_path = os.path.join(ucx_install_path, 'lib', 'pkgconfig')
lf_pkg_path = os.path.join(lf_install_path, 'lib', 'pkgconfig')
build_env['PKG_CONFIG_PATH'] = f"{ucx_pkg_path}:{lf_pkg_path}".strip(':')

ucx_lib_path = os.path.join(ucx_install_path, 'lib')
ucx_plugin_path = os.path.join(ucx_lib_path, 'ucx')
lf_lib_path = os.path.join(lf_install_path, 'lib')
build_env['LD_LIBRARY_PATH'] = f"{ucx_lib_path}:{ucx_plugin_path}:{lf_lib_path}".strip(':')

print(f"--> Using PKG_CONFIG_PATH: {build_env['PKG_CONFIG_PATH']}", flush=True)
print(f"--> Using LD_LIBRARY_PATH: {build_env['LD_LIBRARY_PATH']}", flush=True)

temp_wheel_dir = os.path.join(ROOT_DIR, 'temp_wheelhouse')
# Define the build command for nixl wheel with specific meson arguments
wheel_build_cmd = [
sys.executable, '-m', 'pip', 'wheel', '.',
'--no-deps',
f'--wheel-dir={temp_wheel_dir}',
# Pass meson arguments via pip's config-settings
'--config-settings=setup-args=-Ddisable_gds_backend=true',
f'--config-settings=setup-args=-Dlibfabric_path={lf_install_path}',
f'--config-settings=setup-args=-Ducx_path={ucx_install_path}',
]

run_command(wheel_build_cmd,
cwd=os.path.abspath(NIXL_DIR),
env=build_env)

# -- Step 4: Repair wheel, then replace libfabric --
# auditwheel may bundle an incompatible libfabric, so we need to replace it
print("\n[4/4] Repairing wheel with auditwheel and correcting libfabric...", flush=True)
unrepaired_wheel = find_nixl_wheel_in_cache(temp_wheel_dir)
if not unrepaired_wheel: raise RuntimeError("Failed to find the NIXL wheel after building it.")

# First, run auditwheel to bundle all other dependencies
run_command([sys.executable, '-m', 'auditwheel', 'repair', '--exclude', 'libplugin_UCX.so', unrepaired_wheel, f'--wheel-dir={WHEELS_CACHE_HOME}'], env=build_env)

repaired_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME)
if not repaired_wheel: raise RuntimeError("Failed to find repaired wheel from auditwheel.")

# Now, unpack the repaired wheel to perform surgery on it
wheel_unpack_dir = os.path.join(temp_wheel_dir, "wheel_unpack")
if os.path.exists(wheel_unpack_dir): shutil.rmtree(wheel_unpack_dir)
os.makedirs(wheel_unpack_dir)
run_command(['unzip', '-q', repaired_wheel, '-d', wheel_unpack_dir])

# Find the main NIXL extension file to inspect its dependencies
nixl_extension_search = glob.glob(os.path.join(wheel_unpack_dir, "nixl", "*.so"))
if not nixl_extension_search: raise RuntimeError("Could not find main NIXL .so extension file.")
nixl_extension_file = nixl_extension_search[0]

# Find the .libs directory
libs_dir_search = glob.glob(os.path.join(wheel_unpack_dir, "*.libs"))
if not libs_dir_search: raise RuntimeError("Could not find .libs directory in unpacked wheel.")
libs_dir = libs_dir_search[0]

# Find the incorrect libfabric that auditwheel bundled
incorrect_lib_basename = None
for lib in os.listdir(libs_dir):
if 'libfabric' in lib:
incorrect_lib_basename = lib
break

# Only perform replacement if we found a library to replace
if incorrect_lib_basename:
incorrect_lib_path = os.path.join(libs_dir, incorrect_lib_basename)
print(f"--> Found and deleting incorrect bundled library: {incorrect_lib_basename}", flush=True)
os.remove(incorrect_lib_path)

# Find the correct, pre-built libfabric library
lf_lib_path = os.path.join(lf_install_path, 'lib')
libfabric_so_files = glob.glob(os.path.join(lf_lib_path, 'libfabric.so.1.*'))
if not libfabric_so_files: raise RuntimeError(f"Could not find libfabric.so.1.* in {lf_lib_path}")
correct_libfabric_src = max(libfabric_so_files, key=len)
correct_libfabric_basename = os.path.basename(correct_libfabric_src)

# Copy it into the wheel's .libs directory
print(f"--> Copying correct library '{correct_libfabric_basename}' into wheel", flush=True)
shutil.copy2(correct_libfabric_src, os.path.join(libs_dir, incorrect_lib_path))

# Use patchelf to update the dependency link in the main NIXL extension
# print(f"--> Patching NIXL extension to link against '{correct_libfabric_basename}'", flush=True)
# run_command(['patchelf', '--replace-needed', incorrect_lib_basename, correct_libfabric_basename, nixl_extension_file])
else:
print("--> Warning: Did not find a bundled libfabric to remove. It might have been excluded.", flush=True)

# Repack the corrected wheel, overwriting the one from auditwheel
print(f"--> Repacking corrected wheel to '{os.path.basename(repaired_wheel)}'", flush=True)
run_command(['zip', '-r', repaired_wheel, '.'], cwd=wheel_unpack_dir)

# --- Cleanup ---
shutil.rmtree(temp_wheel_dir)

# --- Final Installation ---
newly_built_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME)
if not newly_built_wheel:
raise RuntimeError("Failed to find the repaired NIXL wheel.")

print(f"--> Successfully built self-contained wheel: {os.path.basename(newly_built_wheel)}. Now installing...",
flush=True)
install_command = [sys.executable, '-m', 'pip', 'install', newly_built_wheel]
if args.force_reinstall:
install_command.insert(-1, '--force-reinstall')

run_command(install_command)
install_nixl()
print("--- NIXL installation complete ---", flush=True)


Expand Down
19 changes: 15 additions & 4 deletions nixl_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test(sender_meta: NixlAgentMetadata, agent: NixlAgent, local_xfer_handle: st
num_iterations = args.num_iterations if not is_warmup else 1
logging.info(f"Starting transfer loop for {num_iterations} iterations...")
logging.info(f"Each iteration transfers {args.blocks_per_xfer} blocks ({data_per_iteration / 1e6:.2f} MB)")
logging.info(f"Do extra h2d copy: {args.do_h2d_cp}")

# Measure total elapsed time for sustained throughput calculation
start_time = time.perf_counter()
Expand All @@ -89,9 +90,11 @@ def test(sender_meta: NixlAgentMetadata, agent: NixlAgent, local_xfer_handle: st
end = start + args.blocks_per_xfer
block_ids = list(range(start, end))
iteration_history.append((start, end))

# Transfer and measure time
latency_ms, data_transferred = read_blocks(block_ids, agent, local_xfer_handle, remote_xfer_handle, sender_meta)


# Transfer and measure time; the extra h2d copy is performed (and timed) when do_h2d_cp=True
latency_ms, data_transferred = read_blocks(block_ids, agent, local_xfer_handle, remote_xfer_handle, sender_meta,
args.do_h2d_cp, local_kv_cache, start, end)

# Verify data_transferred matches expected amount
assert data_transferred == data_per_iteration, f"Data mismatch: expected {data_per_iteration}, got {data_transferred}"
Expand Down Expand Up @@ -161,7 +164,8 @@ def create_xfer_descs(agent: NixlAgent, base_addr: int, num_blocks: int, block_l


def read_blocks(block_ids: Iterator[int], agent: NixlAgent,
local_xfer_handle: str, remote_xfer_handle: str, sender_meta: NixlAgentMetadata):
local_xfer_handle: str, remote_xfer_handle: str, sender_meta: NixlAgentMetadata,
do_h2d_cp, local_kv_cache, start, end):
""" Read blocks from the sender's KV cache using NIXL. """
if not block_ids:
logging.warning("No block IDs provided for transfer.")
Expand Down Expand Up @@ -191,7 +195,12 @@ def read_blocks(block_ids: Iterator[int], agent: NixlAgent,
time.sleep(0.00001)
agent.release_xfer_handle(xfer_handle)
del xfer_handle

if do_h2d_cp:
local_kv_cache[start:end].to("hpu")
t1 = time.perf_counter_ns()



return (t1 - t0) / 1e6, len(local_ids) * sender_meta.block_len # Return latency in ms

Expand Down Expand Up @@ -490,6 +499,8 @@ def receiver_process(args: argparse.Namespace, zmq_host: str = "127.0.0.1"):
help="Role to run: sender or receiver")
parser.add_argument("--zmq-host", type=str, default="127.0.0.1",
help="ZMQ host address (use sender's IP for receiver)")
parser.add_argument("--do-h2d-cp", action="store_true",
help="Do extra h2d copy")
args = parser.parse_args()
#FI_LOG_LEVEL=debug NIXL_LOG_LEVEL=debug
# PT_HPU_POOL_STRATEGY=0 NIXL_PLUGIN_DIR=/workspace/nixl/nixl-nixl_libfabric/build/cp310/src/plugins/libfabric python ts_nixl/nixl_api.py --device-type hpu --nixl_backend libfabric --nixl-memory-type DRAM
Expand Down