diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml
index bb2b002b1a17..1e82445040d4 100644
--- a/.github/workflows/cpu-torch-latest.yml
+++ b/.github/workflows/cpu-torch-latest.yml
@@ -50,5 +50,5 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.4"
-          HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.4"
+          DS_UNITTEST_MASTER_PORT_LOCK_FILE=/tmp/master_ports.lock HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.4"
+          DS_UNITTEST_MASTER_PORT_LOCK_FILE=/tmp/master_ports.lock HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.4"
diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml
index e888c472638f..7c9792334f8a 100644
--- a/.github/workflows/nv-torch-latest-v100.yml
+++ b/.github/workflows/nv-torch-latest-v100.yml
@@ -55,5 +55,5 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.4" --cuda_ver="12.1"
-          pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.4" --cuda_ver="12.1"
+          DS_UNITTEST_MASTER_PORT_LOCK_FILE=/tmp/master_ports.lock pytest $PYTEST_OPTS -s unit/ --torch_ver="2.4" --cuda_ver="12.1" 2>/dev/null | grep MEM_DEBUG
+          DS_UNITTEST_MASTER_PORT_LOCK_FILE=/tmp/master_ports.lock pytest $PYTEST_OPTS -s -m 'sequential' unit/ --torch_ver="2.4" --cuda_ver="12.1" 2>/dev/null | grep MEM_DEBUG
diff --git a/tests/conftest.py b/tests/conftest.py
index 45e8434a021b..6d26e00a333c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,6 +12,8 @@
 import torch
 import warnings
 
+from unit.common import release_port_with_lock
+
 # Set this environment variable for the T5 inference unittest(s) (e.g. google/t5-v1_1-small)
 os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
 
@@ -76,7 +78,8 @@ def pytest_runtest_call(item):
 def pytest_runtest_teardown(item, nextitem):
     if getattr(item.cls, "reuse_dist_env", False) and not nextitem:
         dist_test_class = item.cls()
-        for num_procs, pool in dist_test_class._pool_cache.items():
+        for num_procs, (pool, master_port) in dist_test_class._pool_cache.items():
+            release_port_with_lock(int(master_port))
             dist_test_class._close_pool(pool, num_procs, force=True)
diff --git a/tests/unit/common.py b/tests/unit/common.py
index 69ba4c2708ac..9bb7c8fcd63c 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -11,6 +11,8 @@
 import subprocess
 from abc import ABC, abstractmethod
 from pathlib import Path
+import fcntl
+import tempfile
 
 import torch
 import torch.multiprocessing as mp
@@ -24,6 +26,12 @@
 
 # Worker timeout for tests that hang
 DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))
+DEEPSPEED_MASTER_PORT_LOCK_FILE = os.environ.get('DS_UNITTEST_MASTER_PORT_LOCK_FILE', None)
+
+import logging
+
+logger = mp.log_to_stderr()
+logger.setLevel(logging.INFO)
 
 warn_reuse_dist_env = False
 
@@ -41,23 +49,60 @@ def get_xdist_worker_id():
 
 
 def get_master_port(base_port=29500, port_range_size=1000):
-    xdist_worker_id = get_xdist_worker_id()
-    if xdist_worker_id is not None:
-        # Make xdist workers use different port ranges to avoid race conditions
-        base_port += port_range_size * xdist_worker_id
-
-    # Select first open port in range
-    port = base_port
-    max_port = base_port + port_range_size
-    sock = socket.socket()
-    while port < max_port:
+    global DEEPSPEED_MASTER_PORT_LOCK_FILE
+
+    if DEEPSPEED_MASTER_PORT_LOCK_FILE is None:
+        # Generate file name only
+        fd, DEEPSPEED_MASTER_PORT_LOCK_FILE = tempfile.mkstemp()
+        os.close(fd)
+
+    available_ports = list(range(base_port, base_port + port_range_size))
+
+    with open(DEEPSPEED_MASTER_PORT_LOCK_FILE, 'a+') as port_file:
+        try:
+            fcntl.flock(port_file, fcntl.LOCK_EX)
+            port_file.seek(0)
+            used_ports = {int(line.strip()) for line in port_file if line.strip().isdigit()}
+
+            sock = socket.socket()
+            for port in available_ports:
+                if port not in used_ports:
+                    try:
+                        sock.bind(('', port))
+                        sock.close()
+
+                        port_file.write(f"{port}\n")
+                        port_file.flush()
+                        return str(port)
+                    except OSError:
+                        pass
+            raise IOError('no free ports')
+
+        finally:
+            fcntl.flock(port_file, fcntl.LOCK_UN)
+
+
+def release_port_with_lock(port):
+    if not os.path.exists(DEEPSPEED_MASTER_PORT_LOCK_FILE):
+        raise FileNotFoundError(f"Port file not found: {DEEPSPEED_MASTER_PORT_LOCK_FILE}")
+
+    with open(DEEPSPEED_MASTER_PORT_LOCK_FILE, 'r+') as port_file:
         try:
-            sock.bind(('', port))
-            sock.close()
-            return str(port)
-        except OSError:
-            port += 1
-    raise IOError('no free ports')
+            fcntl.flock(port_file, fcntl.LOCK_EX)
+            lines = port_file.readlines()
+            port_file.seek(0)
+            port_file.truncate(0)
+
+            for line in lines:
+                if int(line.strip()) != port:
+                    port_file.write(line)
+
+            port_file.seek(0)
+            if port_file.read().strip() == "":
+                os.remove(DEEPSPEED_MASTER_PORT_LOCK_FILE)
+
+        finally:
+            fcntl.flock(port_file, fcntl.LOCK_UN)
 
 
 def _get_cpu_socket_count():
@@ -176,10 +221,13 @@ def _launch_daemonic_procs(self, num_procs):
         # Create process pool or use cached one
         master_port = None
 
-        if get_accelerator().device_name() == 'hpu':
-            if self.reuse_dist_env:
-                print("Ignoring reuse_dist_env for hpu")
-                self.reuse_dist_env = False
+        # if get_accelerator().device_name() == 'hpu':
+        #     if self.reuse_dist_env:
+        #         print("Ignoring reuse_dist_env for hpu")
+        #         self.reuse_dist_env = False
+
+        print("[MEM_DEBUG] Ignoring reuse_dist_env and forcibly setting it to False")
+        self.reuse_dist_env = False
 
         global warn_reuse_dist_env
         if self.reuse_dist_env and not warn_reuse_dist_env:
@@ -190,9 +238,9 @@ def _launch_daemonic_procs(self, num_procs):
 
         if self.reuse_dist_env:
             if num_procs not in self._pool_cache:
-                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                 master_port = get_master_port()
-            pool = self._pool_cache[num_procs]
+                self._pool_cache[num_procs] = (mp.Pool(processes=num_procs), master_port)
+            pool, _ = self._pool_cache[num_procs]
         else:
             pool = mp.Pool(processes=num_procs)
             master_port = get_master_port()
@@ -212,6 +260,8 @@ def _launch_daemonic_procs(self, num_procs):
             # Regardless of the outcome, ensure proper teardown
             # Tear down distributed environment and close process pools
             self._close_pool(pool, num_procs)
+            if not self.reuse_dist_env:
+                release_port_with_lock(int(master_port))
 
         # If we skipped a test, propagate that to this process
         if any(skip_msgs):
@@ -221,53 +271,56 @@ def _launch_daemonic_procs(self, num_procs):
 
     def _launch_non_daemonic_procs(self, num_procs):
         assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes"
 
-        master_port = get_master_port()
-        skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
-        processes = []
-        prev_start_method = mp.get_start_method()
-        mp.set_start_method('spawn', force=True)
-        for local_rank in range(num_procs):
-            p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
-            p.start()
-            processes.append(p)
-        mp.set_start_method(prev_start_method, force=True)
-
-        # Now loop and wait for a test to complete. The spin-wait here isn't a big
-        # deal because the number of processes will be O(#GPUs) << O(#CPUs).
-        any_done = False
-        start = time.time()
-        while (not any_done) and ((time.time() - start) < self.exec_timeout):
-            for p in processes:
-                if not p.is_alive():
-                    any_done = True
-                    break
-            time.sleep(.1)  # So we don't hog CPU
-
-        # If we hit the timeout, then presume a test is hanged
-        if not any_done:
+        try:
+            master_port = get_master_port()
+            skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
+            processes = []
+            prev_start_method = mp.get_start_method()
+            mp.set_start_method('spawn', force=True)
+            for local_rank in range(num_procs):
+                p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
+                p.start()
+                processes.append(p)
+            mp.set_start_method(prev_start_method, force=True)
+
+            # Now loop and wait for a test to complete. The spin-wait here isn't a big
+            # deal because the number of processes will be O(#GPUs) << O(#CPUs).
+            any_done = False
+            start = time.time()
+            while (not any_done) and ((time.time() - start) < self.exec_timeout):
+                for p in processes:
+                    if not p.is_alive():
+                        any_done = True
+                        break
+                time.sleep(.1)  # So we don't hog CPU
+
+            # If we hit the timeout, then presume a test is hanged
+            if not any_done:
+                for p in processes:
+                    p.terminate()
+                pytest.exit("Test hanged, exiting", returncode=1)
+
+            # Wait for all other processes to complete
             for p in processes:
-                p.terminate()
-            pytest.exit("Test hanged, exiting", returncode=1)
-
-        # Wait for all other processes to complete
-        for p in processes:
-            p.join(self.exec_timeout)
-
-        failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
-        for rank, p in failed:
-            # If it still hasn't terminated, kill it because it hung.
-            if p.exitcode is None:
-                p.terminate()
-                pytest.fail(f'Worker {rank} hung.', pytrace=False)
-            if p.exitcode < 0:
-                pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}', pytrace=False)
-            if p.exitcode > 0:
-                pytest.fail(f'Worker {rank} exited with code {p.exitcode}', pytrace=False)
-
-        if not skip_msg.empty():
-            # This assumed all skip messages are the same, it may be useful to
-            # add a check here to assert all exit messages are equal
-            pytest.skip(skip_msg.get())
+                p.join(self.exec_timeout)
+
+            failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
+            for rank, p in failed:
+                # If it still hasn't terminated, kill it because it hung.
+                if p.exitcode is None:
+                    p.terminate()
+                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
+                if p.exitcode < 0:
+                    pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}', pytrace=False)
+                if p.exitcode > 0:
+                    pytest.fail(f'Worker {rank} exited with code {p.exitcode}', pytrace=False)
+
+            if not skip_msg.empty():
+                # This assumed all skip messages are the same, it may be useful to
+                # add a check here to assert all exit messages are equal
+                pytest.skip(skip_msg.get())
+        finally:
+            release_port_with_lock(int(master_port))
 
     def _launch_procs(self, num_procs):
@@ -283,6 +336,39 @@ def _launch_procs(self, num_procs):
         # Set start method to `forkserver` (or `fork`)
         mp.set_start_method('forkserver', force=True)
 
+        def print_device_memory_usage():
+            import pynvml
+
+            # Get the number of GPUs
+            device_count = pynvml.nvmlDeviceGetCount()
+
+            # Iterate over each GPU and print memory usage
+            for i in range(device_count):
+                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+                name = pynvml.nvmlDeviceGetName(handle)
+                print(
+                    f"[MEM_DEBUG] GPU {i}: {name} Total memory: {info.total} Used memory: {info.used} Free memory: {info.free}"
+                )
+
+        def print_cpu_memory_usage():
+            import psutil
+            vm_stats = psutil.virtual_memory()
+            used = vm_stats.total - vm_stats.available
+            print(f"[MEM_DEBUG] CPU Memory Usage: {used} Available: {vm_stats.available}")
+
+        print(f"[MEM_DEBUG] Running test with {num_procs} processes")
+        if get_accelerator()._name == 'cuda':
+            print_device_memory_usage()
+        print_cpu_memory_usage()
+
+        import getpass
+        current_user = getpass.getuser()
+        import psutil
+        user_process_count = sum(1 for proc in psutil.process_iter(['username'])
+                                 if proc.info['username'] == current_user)
+        print(f"[MEM_DEBUG] User process count: {user_process_count}")
+
         if self.non_daemonic_procs:
             self._launch_non_daemonic_procs(num_procs)
         else:
@@ -305,6 +391,11 @@ def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""):
         # turn off NCCL logging if set
         os.environ.pop('NCCL_DEBUG', None)
 
+        if "MASTER_ADDR" in os.environ:
+            print(
+                f"[MEM_DEBUG] [r{os.environ['RANK']}] MASTER_ADDR: {os.environ['MASTER_ADDR']}, MASTER_PORT: {os.environ['MASTER_PORT']}, LOCAL_RANK: {os.environ['LOCAL_RANK']}, RANK: {os.environ['RANK']}, LOCAL_SIZE: {os.environ['LOCAL_SIZE']}, WORLD_SIZE: {os.environ['WORLD_SIZE']}"
+            )
+
         if get_accelerator().is_available():
             set_accelerator_visible()
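Note on the change above: get_master_port no longer relies on per-xdist-worker port ranges; instead, every worker records the port it claims in a shared lock file (DS_UNITTEST_MASTER_PORT_LOCK_FILE) while holding an exclusive flock, and release_port_with_lock removes the entry once the test group tears down. The sketch below is a standalone, hedged illustration of that reservation protocol, not the patch itself; the path /tmp/ports.lock and the helper names reserve_port/free_port are invented for the example, and it assumes a POSIX system where fcntl is available.

# Standalone sketch of file-locked port reservation, mirroring the scheme above.
# Assumption: LOCK_FILE plays the role of DS_UNITTEST_MASTER_PORT_LOCK_FILE.
import fcntl
import socket

LOCK_FILE = "/tmp/ports.lock"  # hypothetical path for this example


def reserve_port(base_port=29500, port_range_size=1000):
    """Claim a free port and record it in LOCK_FILE under an exclusive lock."""
    with open(LOCK_FILE, "a+") as f:  # 'a+' creates the file on first use
        fcntl.flock(f, fcntl.LOCK_EX)  # serialize claims across pytest workers
        try:
            f.seek(0)
            used = {int(tok) for tok in f.read().split() if tok.isdigit()}
            for port in range(base_port, base_port + port_range_size):
                if port in used:
                    continue
                try:
                    probe = socket.socket()
                    probe.bind(("", port))  # confirm the OS also sees it as free
                    probe.close()
                except OSError:
                    continue
                f.write(f"{port}\n")  # record the claim before dropping the lock
                f.flush()
                return port
            raise IOError("no free ports")
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


def free_port(port):
    """Drop a recorded port; reserve_port must have created LOCK_FILE already."""
    with open(LOCK_FILE, "r+") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            entries = f.readlines()
            f.seek(0)
            f.truncate(0)
            f.writelines(line for line in entries if line.strip() != str(port))
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)

A worker would call reserve_port() once per distributed test group, export the result as MASTER_PORT, and call free_port() during teardown, which mirrors how _launch_daemonic_procs and _launch_non_daemonic_procs pair get_master_port with release_port_with_lock in the patch.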