diff --git a/.gitignore b/.gitignore index 2f7896d1..663cefe0 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,6 @@ target/ +CLAUDE.md + +# macOS resource fork files +._* +.DS_Store diff --git a/CHANGELOG.md b/CHANGELOG.md index e55fb949..9d64fbc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] +### Added + +* Add hierarchical process tree view feature #1 - @wezm + ### Fixed * Update CONTRIBUTING information #438 - @YJDoc2 @cyqsimon diff --git a/README.md b/README.md index 538ef7fe..2cda1fe9 100644 --- a/README.md +++ b/README.md @@ -184,10 +184,19 @@ Options: -a, --addresses Show remote addresses table only -u, --unit-family Choose a specific family of units [default: bin-bytes] [possible values: bin-bytes, bin-bits, si-bytes, si-bits] -t, --total-utilization Show total (cumulative) usages + --tree-view Display processes in hierarchical tree view showing parent-child relationships -h, --help Print help (see more with '--help') -V, --version Print version ``` +### Process Tree View + +The `--tree-view` flag enables hierarchical process visualization, showing parent-child relationships between processes. This is particularly useful for understanding which services or daemons are responsible for network activity, and for visualizing complex applications with multiple child processes. When enabled: + +- Processes are displayed in a tree structure with indentation showing the hierarchy +- Bandwidth usage is aggregated up the tree (parent processes show their own usage plus their children's) +- Each process shows its own PID and its relationship to other processes + ## Contributing See [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/src/cli.rs b/src/cli.rs index af21c6a9..048027f2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -60,6 +60,10 @@ pub struct RenderOpts { #[arg(short, long)] /// Show total (cumulative) usages pub total_utilization: bool, + + #[arg(long)] + /// Display processes in hierarchical tree view showing parent-child relationships + pub tree_view: bool, } // IMPRV: it would be nice if we can `#[cfg_attr(not(build), derive(strum::EnumIter))]` this diff --git a/src/display/components/table.rs b/src/display/components/table.rs index 4c66d3c4..2b9f5b88 100644 --- a/src/display/components/table.rs +++ b/src/display/components/table.rs @@ -261,7 +261,12 @@ impl Table { pub fn create_processes_table(state: &UIState) -> Self { use DisplayLayout as D; - let title = "Utilization by process name"; + let title = if state.tree_view { + "Utilization by process tree" + } else { + "Utilization by process name" + }; + let width_cutoffs = vec![ (0, D::C2([16, 18])), (50, D::C3([16, 12, 20])), @@ -279,22 +284,50 @@ impl Table { "Rate (Up / Down)" }, ]; - let rows = state - .processes - .iter() - .map(|(proc_info, data_for_process)| { - [ - proc_info.name.to_string(), - proc_info.pid.to_string(), - data_for_process.connection_count.to_string(), - display_upload_and_download( - data_for_process, - state.unit_family, - state.cumulative_mode, - ), - ] - }) - .collect(); + + let rows = if state.tree_view { + // Build hierarchical process list from process trees using aggregated data + let mut tree_rows = Vec::new(); + for tree in &state.process_trees { + for (proc_info, depth) in tree.iter_depth_first() { + // Use aggregated data which includes children's bandwidth + if let Some(data_for_process) = state.aggregated_processes_map.get(proc_info) { + let indent = " ".repeat(depth); + let process_name = format!("{}{}", indent, proc_info.name); + tree_rows.push([ + process_name, + proc_info.pid.to_string(), + data_for_process.connection_count.to_string(), + display_upload_and_download( + data_for_process, + state.unit_family, + state.cumulative_mode, + ), + ]); + } + } + } + tree_rows + } else { + // Regular flat process list + state + .processes + .iter() + .map(|(proc_info, data_for_process)| { + [ + proc_info.name.to_string(), + proc_info.pid.to_string(), + data_for_process.connection_count.to_string(), + display_upload_and_download( + data_for_process, + state.unit_family, + state.cumulative_mode, + ), + ] + }) + .collect() + }; + let column_selector = Rc::new(|layout: &D| match layout { D::C2(_) => vec![0, 3], D::C3(_) => vec![0, 2, 3], diff --git a/src/display/mod.rs b/src/display/mod.rs index 127c2bc1..2462254d 100644 --- a/src/display/mod.rs +++ b/src/display/mod.rs @@ -1,3 +1,10 @@ +//! Terminal user interface components +//! +//! This module provides the display functionality for bandwhich: +//! - Terminal UI rendering using ratatui +//! - Raw output mode for piping to other programs +//! - UI state management and component rendering + mod components; mod raw_terminal_backend; mod ui; diff --git a/src/display/ui.rs b/src/display/ui.rs index 0f94cb67..140227db 100644 --- a/src/display/ui.rs +++ b/src/display/ui.rs @@ -28,15 +28,16 @@ where B: Backend, { pub fn new(terminal_backend: B, opts: &Opt) -> Self { - let mut terminal = Terminal::new(terminal_backend).unwrap(); - terminal.clear().unwrap(); - terminal.hide_cursor().unwrap(); + let mut terminal = Terminal::new(terminal_backend).expect("Failed to create terminal"); + terminal.clear().expect("Failed to clear terminal"); + terminal.hide_cursor().expect("Failed to hide cursor"); let state = { let mut state = UIState::default(); state.interface_name.clone_from(&opts.interface); state.unit_family = opts.render_opts.unit_family.into(); state.cumulative_mode = opts.render_opts.total_utilization; state.show_dns = opts.show_dns; + state.tree_view = opts.render_opts.tree_view; state }; Ui { @@ -140,9 +141,9 @@ where show_dns: self.state.show_dns, }, }; - self.terminal - .draw(|frame| layout.render(frame, frame.area(), table_cycle_offset)) - .unwrap(); + let _ = self + .terminal + .draw(|frame| layout.render(frame, frame.area(), table_cycle_offset)); } fn get_tables_to_display(&self) -> Vec { @@ -187,6 +188,6 @@ where self.ip_to_host.extend(ip_to_host); } pub fn end(&mut self) { - self.terminal.show_cursor().unwrap(); + let _ = self.terminal.show_cursor(); } } diff --git a/src/display/ui_state.rs b/src/display/ui_state.rs index a32a6dac..ee950173 100644 --- a/src/display/ui_state.rs +++ b/src/display/ui_state.rs @@ -10,7 +10,7 @@ use log::warn; use crate::{ display::BandwidthUnitFamily, network::{Connection, LocalSocket, Utilization}, - os::ProcessInfo, + os::{aggregate_bandwidth_by_tree, build_process_trees, ProcessInfo, ProcessTreeNode}, }; static RECALL_LENGTH: usize = 5; @@ -89,11 +89,14 @@ pub struct UIState { pub total_bytes_uploaded: u128, pub cumulative_mode: bool, pub show_dns: bool, + pub tree_view: bool, pub unit_family: BandwidthUnitFamily, pub utilization_data: VecDeque, pub processes_map: HashMap, pub remote_addresses_map: HashMap, pub connections_map: HashMap, + pub process_trees: Vec, + pub aggregated_processes_map: HashMap, /// Used for reducing logging noise. known_orphan_sockets: VecDeque, } @@ -224,6 +227,16 @@ impl UIState { self.processes = sort_and_prune(&mut self.processes_map); self.remote_addresses = sort_and_prune(&mut self.remote_addresses_map); self.connections = sort_and_prune(&mut self.connections_map); + + // Build process trees if tree view is enabled + if self.tree_view { + let all_processes: Vec = self.processes_map.keys().cloned().collect(); + self.process_trees = build_process_trees(all_processes); + + // Aggregate bandwidth by process tree + self.aggregated_processes_map = + aggregate_bandwidth_by_tree(&self.process_trees, &self.processes_map); + } } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..dadcc83e --- /dev/null +++ b/src/error.rs @@ -0,0 +1,56 @@ +//! Custom error types for bandwhich +//! +//! This module provides structured error handling throughout the application, +//! replacing generic error types with domain-specific ones for better +//! error messages and recovery strategies. + +use std::io; +use thiserror::Error; + +/// Main error type for bandwhich operations +#[derive(Debug, Error)] +pub enum BandwhichError { + /// Terminal initialization or operation failed + #[error("Terminal error: {0}")] + Terminal(#[from] io::Error), + + /// Thread spawning or joining failed + #[error("Thread error: {0}")] + #[allow(dead_code)] + Thread(String), + + /// Lock acquisition failed (poisoned mutex) + #[error("Lock poisoned: {0}")] + LockPoisoned(String), + + /// DNS resolution error + #[error("DNS error: {0}")] + #[allow(dead_code)] + Dns(String), + + /// Network interface error + #[error("Network interface error: {0}")] + #[allow(dead_code)] + NetworkInterface(String), + + /// Process information retrieval error + #[error("Process info error: {0}")] + #[allow(dead_code)] + ProcessInfo(String), + + /// Configuration or CLI argument error + #[error("Configuration error: {0}")] + #[allow(dead_code)] + Config(String), +} + +/// Result type alias for bandwhich operations +#[allow(dead_code)] +pub type Result = std::result::Result; + +/// Convert from std::sync::PoisonError to BandwhichError +impl From> for BandwhichError { + fn from(err: std::sync::PoisonError) -> Self { + BandwhichError::LockPoisoned(format!("Mutex poisoned: {err}")) + } +} diff --git a/src/main.rs b/src/main.rs index 8ace86fb..9ba6023e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,29 @@ +//! Bandwhich - Terminal bandwidth utilization tool +//! +//! This is the main entry point for bandwhich, a CLI utility for displaying +//! current network utilization by process, connection and remote IP/hostname. +//! +//! # Architecture +//! +//! The application uses a multi-threaded architecture with three main components: +//! +//! 1. **Display Handler Thread**: Updates the terminal UI every second with current +//! network statistics. Handles both TUI and raw output modes. +//! +//! 2. **Terminal Event Handler Thread**: Processes keyboard input for interactive +//! controls (pause/resume, quit, tab to switch views). +//! +//! 3. **Sniffer Threads**: One per network interface, continuously captures packets +//! and updates shared network utilization statistics. +//! +//! All threads communicate through shared state protected by Arc> and +//! Arc> for thread-safe access. + #![deny(clippy::enum_glob_use)] mod cli; mod display; +mod error; mod network; mod os; #[cfg(test)] @@ -36,6 +58,7 @@ use simplelog::WriteLogger; use crate::cli::Opt; use crate::os::ProcessInfo; +/// Refresh interval for the display thread - updates UI every second const DISPLAY_DELTA: Duration = Duration::from_millis(1000); fn main() -> eyre::Result<()> { @@ -111,19 +134,22 @@ where let display_handler = thread::Builder::new() .name("display_handler".to_string()) .spawn({ - let running = running.clone(); - let paused = paused.clone(); - let table_cycle_offset = table_cycle_offset.clone(); + let running = Arc::clone(&running); + let paused = Arc::clone(&paused); + let table_cycle_offset = Arc::clone(&table_cycle_offset); - let network_utilization = network_utilization.clone(); - let last_start_time = last_start_time.clone(); - let cumulative_time = cumulative_time.clone(); - let ui = ui.clone(); + let network_utilization = Arc::clone(&network_utilization); + let last_start_time = Arc::clone(&last_start_time); + let cumulative_time = Arc::clone(&cumulative_time); + let ui = Arc::clone(&ui); move || { while running.load(Ordering::Acquire) { let render_start_time = Instant::now(); - let utilization = network_utilization.lock().unwrap().clone_and_reset(); + let utilization = network_utilization + .lock() + .expect("network_utilization lock poisoned") + .clone_and_reset(); let OpenSockets { sockets_to_procs } = get_open_sockets(); let mut ip_to_host = IpTable::new(); if let Some(dns_client) = dns_client.as_mut() { @@ -137,15 +163,19 @@ where dns_client.resolve(unresolved_ips); } { - let mut ui = ui.lock().unwrap(); + let mut ui = ui.lock().expect("ui lock poisoned"); let paused = paused.load(Ordering::SeqCst); let table_cycle_offset = table_cycle_offset.load(Ordering::SeqCst); if !paused { ui.update_state(sockets_to_procs, utilization, ip_to_host); } let elapsed_time = elapsed_time( - *last_start_time.read().unwrap(), - *cumulative_time.read().unwrap(), + *last_start_time + .read() + .expect("last_start_time read lock poisoned"), + *cumulative_time + .read() + .expect("cumulative_time read lock poisoned"), paused, ); @@ -161,22 +191,23 @@ where } } if !raw_mode { - let mut ui = ui.lock().unwrap(); - ui.end(); + if let Ok(mut ui) = ui.lock() { + ui.end(); + } } } }) - .unwrap(); + .expect("Failed to spawn display_handler thread"); let terminal_event_handler = thread::Builder::new() .name("terminal_events_handler".to_string()) .spawn({ - let running = running.clone(); + let running = Arc::clone(&running); let display_handler = display_handler.thread().clone(); move || { for evt in terminal_events { - let mut ui = ui.lock().unwrap(); + let mut ui = ui.lock().expect("ui lock poisoned"); match evt { Event::Resize(_x, _y) if !raw_mode => { @@ -184,8 +215,12 @@ where ui.draw( paused, elapsed_time( - *last_start_time.read().unwrap(), - *cumulative_time.read().unwrap(), + *last_start_time + .read() + .expect("last_start_time read lock poisoned"), + *cumulative_time + .read() + .expect("cumulative_time read lock poisoned"), paused, ), table_cycle_offset.load(Ordering::SeqCst), @@ -225,13 +260,22 @@ where }) => { let restarting = paused.fetch_xor(true, Ordering::SeqCst); if restarting { - *last_start_time.write().unwrap() = Instant::now(); + *last_start_time + .write() + .expect("last_start_time write lock poisoned") = Instant::now(); } else { - let last_start_time_copy = *last_start_time.read().unwrap(); - let current_cumulative_time_copy = *cumulative_time.read().unwrap(); + let last_start_time_copy = *last_start_time + .read() + .expect("last_start_time read lock poisoned"); + let current_cumulative_time_copy = *cumulative_time + .read() + .expect("cumulative_time read lock poisoned"); let new_cumulative_time = current_cumulative_time_copy + last_start_time_copy.elapsed(); - *cumulative_time.write().unwrap() = new_cumulative_time; + *cumulative_time + .write() + .expect("cumulative_time write lock poisoned") = + new_cumulative_time; } display_handler.unpark(); @@ -244,8 +288,12 @@ where }) => { let paused = paused.load(Ordering::SeqCst); let elapsed_time = elapsed_time( - *last_start_time.read().unwrap(), - *cumulative_time.read().unwrap(), + *last_start_time + .read() + .expect("last_start_time read lock poisoned"), + *cumulative_time + .read() + .expect("cumulative_time read lock poisoned"), paused, ); let table_count = ui.get_table_count(); @@ -258,7 +306,7 @@ where } } }) - .unwrap(); + .expect("Failed to spawn terminal_event_handler thread"); active_threads.push(display_handler); active_threads.push(terminal_event_handler); @@ -268,9 +316,9 @@ where .into_iter() .map(|(iface, frames)| { let name = format!("sniffing_handler_{}", iface.name); - let running = running.clone(); + let running = Arc::clone(&running); let show_dns = opts.show_dns; - let network_utilization = network_utilization.clone(); + let network_utilization = Arc::clone(&network_utilization); thread::Builder::new() .name(name) @@ -279,16 +327,19 @@ where while running.load(Ordering::Acquire) { if let Some(segment) = sniffer.next() { - network_utilization.lock().unwrap().ingest(segment); + network_utilization + .lock() + .expect("network_utilization lock poisoned") + .ingest(segment); } } }) - .unwrap() + .expect("Failed to spawn sniffer thread") }) .collect::>(); active_threads.extend(sniffer_threads); for thread_handler in active_threads { - thread_handler.join().unwrap() + thread_handler.join().expect("Failed to join thread") } } diff --git a/src/network/dns/client.rs b/src/network/dns/client.rs index 95148172..400d0c5f 100644 --- a/src/network/dns/client.rs +++ b/src/network/dns/client.rs @@ -14,6 +14,8 @@ use crate::network::dns::{resolver::Lookup, IpTable}; type PendingAddrs = HashSet; +/// Size of the channel buffer for DNS resolution requests +/// Large enough to handle bursts of new connections without blocking const CHANNEL_SIZE: usize = 1_000; pub struct Client { @@ -48,9 +50,9 @@ impl Client { async move { if let Some(name) = resolver.lookup(ip).await { - cache.lock().unwrap().insert(ip, name); + cache.lock().expect("cache lock poisoned").insert(ip, name); } - pending.lock().unwrap().remove(&ip); + pending.lock().expect("pending lock poisoned").remove(&ip); } }); } @@ -71,17 +73,22 @@ impl Client { // Remove ips that are already being resolved let ips = ips .into_iter() - .filter(|ip| self.pending.lock().unwrap().insert(*ip)) + .filter(|ip| { + self.pending + .lock() + .expect("pending lock poisoned") + .insert(*ip) + }) .collect::>(); if !ips.is_empty() { // Discard the message if the channel is full; it will be retried eventually - let _ = self.tx.as_mut().unwrap().try_send(ips); + let _ = self.tx.as_mut().expect("tx should be Some").try_send(ips); } } pub fn cache(&mut self) -> IpTable { - let cache = self.cache.lock().unwrap(); + let cache = self.cache.lock().expect("cache lock poisoned"); cache.clone() } } @@ -89,7 +96,9 @@ impl Client { impl Drop for Client { fn drop(&mut self) { // Do the Option dance to be able to drop the sender so that the receiver finishes and the thread can be joined - drop(self.tx.take().unwrap()); - self.handle.take().unwrap().join().unwrap(); + drop(self.tx.take().expect("tx should be Some")); + if let Some(handle) = self.handle.take() { + let _ = handle.join(); + } } } diff --git a/src/network/mod.rs b/src/network/mod.rs index 725a58b1..165dc4af 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -1,3 +1,10 @@ +//! Network packet capture and analysis +//! +//! This module provides the core networking functionality for bandwhich: +//! - Packet sniffing from network interfaces +//! - Connection tracking and bandwidth utilization +//! - DNS resolution for IP addresses + mod connection; pub mod dns; mod sniffer; diff --git a/src/os/linux.rs b/src/os/linux.rs index 08a4aeb6..abb9cbd7 100644 --- a/src/os/linux.rs +++ b/src/os/linux.rs @@ -17,7 +17,12 @@ pub(crate) fn get_open_sockets() -> OpenSockets { let Ok(fds) = process.fd() else { continue }; let Ok(stat) = process.stat() else { continue }; let proc_name = stat.comm; - let proc_info = ProcessInfo::new(&proc_name, stat.pid as u32); + let parent_pid = if stat.ppid > 0 { + Some(stat.ppid as u32) + } else { + None + }; + let proc_info = ProcessInfo::with_parent(&proc_name, stat.pid as u32, parent_pid); for fd in fds.filter_map(|res| res.ok()) { if let FDTarget::Socket(inode) = fd.target { inode_to_proc.insert(inode, proc_info.clone()); diff --git a/src/os/lsof_utils.rs b/src/os/lsof_utils.rs index 4a61c675..0fedefcc 100644 --- a/src/os/lsof_utils.rs +++ b/src/os/lsof_utils.rs @@ -1,4 +1,7 @@ -use std::{ffi::OsStr, net::IpAddr, process::Command}; +use std::{collections::HashMap, ffi::OsStr, net::IpAddr, process::Command}; + +#[cfg(unix)] +use std::os::unix::process::ExitStatusExt; use log::warn; use once_cell::sync::Lazy; @@ -29,7 +32,7 @@ fn get_null_addr(ip_type: &str) -> &str { } impl RawConnection { - pub fn new(raw_line: &str) -> Option { + pub fn new(raw_line: &str, parent_pids: &HashMap) -> Option { // Example row // com.apple 664 user 198u IPv4 0xeb179a6650592b8d 0t0 TCP 192.168.1.187:58535->1.2.3.4:443 (ESTABLISHED) let columns: Vec<&str> = raw_line.split_ascii_whitespace().collect(); @@ -38,7 +41,8 @@ impl RawConnection { } let process_name = columns[0].replace("\\x20", " "); let pid = columns[1].parse().ok()?; - let proc_info = ProcessInfo::new(&process_name, pid); + let parent_pid = parent_pids.get(&pid).copied(); + let proc_info = ProcessInfo::with_parent(&process_name, pid, parent_pid); // Unneeded // let username = columns[2]; // let fd = columns[3]; @@ -139,9 +143,37 @@ impl RawConnection { } } +fn get_parent_pids() -> HashMap { + let output = Command::new("ps") + .args(["-eo", "pid,ppid"]) + .output() + .unwrap_or_else(|_| std::process::Output { + stdout: Vec::new(), + stderr: Vec::new(), + status: std::process::ExitStatus::from_raw(1), + }); + + let content = String::from_utf8_lossy(&output.stdout); + let mut parent_map = HashMap::new(); + + for line in content.lines().skip(1) { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 2 { + if let (Ok(pid), Ok(ppid)) = (parts[0].parse::(), parts[1].parse::()) { + if ppid > 0 { + parent_map.insert(pid, ppid); + } + } + } + } + + parent_map +} + pub fn get_connections() -> RawConnections { let content = run(["-n", "-P", "-i4", "-i6", "+c", "0"]); - RawConnections::new(content) + let parent_pids = get_parent_pids(); + RawConnections::new(content, parent_pids) } fn run(args: I) -> String @@ -162,8 +194,11 @@ pub struct RawConnections { } impl RawConnections { - pub fn new(content: String) -> RawConnections { - let lines: Vec = content.lines().flat_map(RawConnection::new).collect(); + pub fn new(content: String, parent_pids: HashMap) -> RawConnections { + let lines: Vec = content + .lines() + .filter_map(|line| RawConnection::new(line, &parent_pids)) + .collect(); RawConnections { content: lines } } @@ -192,7 +227,8 @@ com.apple 590 etoledom 204u IPv4 0x28ffb9c04111253f 0t0 TCP 192.168.1. #[test] fn test_iterator_multiline() { - let iterator = RawConnections::new(String::from(FULL_RAW_OUTPUT)); + let parent_pids = HashMap::new(); + let iterator = RawConnections::new(String::from(FULL_RAW_OUTPUT), parent_pids); let connections: Vec = iterator.collect(); assert_eq!(connections.len(), 4); } @@ -206,13 +242,15 @@ com.apple 590 etoledom 204u IPv4 0x28ffb9c04111253f 0t0 TCP 192.168.1. test_raw_connection_is_created_from_raw_output(IPV6_LINE_RAW_OUTPUT); } fn test_raw_connection_is_created_from_raw_output(raw_output: &str) { - let connection = RawConnection::new(raw_output); + let parent_pids = HashMap::new(); + let connection = RawConnection::new(raw_output, &parent_pids); assert!(connection.is_some()); } #[test] fn test_raw_connection_is_not_created_from_wrong_raw_output() { - let connection = RawConnection::new("not a process"); + let parent_pids = HashMap::new(); + let connection = RawConnection::new("not a process", &parent_pids); assert!(connection.is_none()); } @@ -225,7 +263,8 @@ com.apple 590 etoledom 204u IPv4 0x28ffb9c04111253f 0t0 TCP 192.168.1. test_raw_connection_parse_local_port(IPV6_LINE_RAW_OUTPUT); } fn test_raw_connection_parse_local_port(raw_output: &str) { - let connection = RawConnection::new(raw_output).unwrap(); + let parent_pids = HashMap::new(); + let connection = RawConnection::new(raw_output, &parent_pids).unwrap(); assert_eq!(connection.get_local_port(), Some(1111)); } @@ -238,7 +277,8 @@ com.apple 590 etoledom 204u IPv4 0x28ffb9c04111253f 0t0 TCP 192.168.1. test_raw_connection_parse_protocol(IPV6_LINE_RAW_OUTPUT); } fn test_raw_connection_parse_protocol(raw_line: &str) { - let connection = RawConnection::new(raw_line).unwrap(); + let parent_pids = HashMap::new(); + let connection = RawConnection::new(raw_line, &parent_pids).unwrap(); assert_eq!(connection.get_protocol(), Some(Protocol::Udp)); } @@ -251,7 +291,8 @@ com.apple 590 etoledom 204u IPv4 0x28ffb9c04111253f 0t0 TCP 192.168.1. test_raw_connection_parse_process_name(IPV6_LINE_RAW_OUTPUT); } fn test_raw_connection_parse_process_name(raw_line: &str) { - let connection = RawConnection::new(raw_line).unwrap(); + let parent_pids = HashMap::new(); + let connection = RawConnection::new(raw_line, &parent_pids).unwrap(); assert_eq!(connection.proc_info.name, String::from("ProcessName")); } } diff --git a/src/os/mod.rs b/src/os/mod.rs index 434c9d7c..e4ade740 100644 --- a/src/os/mod.rs +++ b/src/os/mod.rs @@ -1,3 +1,15 @@ +//! Operating system specific functionality +//! +//! This module provides platform-specific implementations for: +//! - Mapping network sockets to processes +//! - Terminal event handling +//! - Network interface discovery +//! +//! Supported platforms: +//! - Linux/Android: Uses /proc filesystem +//! - macOS/FreeBSD: Uses lsof command +//! - Windows: Uses Windows APIs + #[cfg(any(target_os = "android", target_os = "linux"))] mod linux; diff --git a/src/os/shared.rs b/src/os/shared.rs index 511ffc1b..b083ccbc 100644 --- a/src/os/shared.rs +++ b/src/os/shared.rs @@ -1,4 +1,5 @@ use std::{ + collections::HashMap, io::{self, ErrorKind, Write}, net::Ipv4Addr, time, @@ -24,6 +25,7 @@ use crate::os::windows::get_open_sockets; pub struct ProcessInfo { pub name: String, pub pid: u32, + pub parent_pid: Option, } impl ProcessInfo { @@ -31,8 +33,187 @@ impl ProcessInfo { Self { name: name.to_string(), pid, + parent_pid: None, } } + + pub fn with_parent(name: &str, pid: u32, parent_pid: Option) -> Self { + Self { + name: name.to_string(), + pid, + parent_pid, + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct ProcessTreeNode { + pub process_info: ProcessInfo, + pub children: Vec, + #[allow(dead_code)] + pub depth: usize, +} + +impl ProcessTreeNode { + pub fn new(process_info: ProcessInfo, depth: usize) -> Self { + Self { + process_info, + children: Vec::new(), + depth, + } + } + + pub fn add_child(&mut self, child: ProcessTreeNode) { + self.children.push(child); + } + + #[allow(dead_code)] + pub fn find_node_mut(&mut self, pid: u32) -> Option<&mut ProcessTreeNode> { + if self.process_info.pid == pid { + return Some(self); + } + + for child in &mut self.children { + if let Some(node) = child.find_node_mut(pid) { + return Some(node); + } + } + + None + } + + pub fn iter_depth_first(&self) -> ProcessTreeIterator<'_> { + ProcessTreeIterator::new(self) + } +} + +pub struct ProcessTreeIterator<'a> { + stack: Vec<(&'a ProcessTreeNode, usize)>, +} + +impl<'a> ProcessTreeIterator<'a> { + fn new(root: &'a ProcessTreeNode) -> Self { + let stack = vec![(root, 0)]; + Self { stack } + } +} + +impl<'a> Iterator for ProcessTreeIterator<'a> { + type Item = (&'a ProcessInfo, usize); + + fn next(&mut self) -> Option { + if let Some((node, depth)) = self.stack.pop() { + // Add children to stack in reverse order for depth-first traversal + for child in node.children.iter().rev() { + self.stack.push((child, depth + 1)); + } + + Some((&node.process_info, depth)) + } else { + None + } + } +} + +pub fn build_process_trees(processes: Vec) -> Vec { + let mut process_map: HashMap = HashMap::new(); + let mut children_map: HashMap> = HashMap::new(); + let mut roots = Vec::new(); + + // First pass: build maps + for process in processes { + let pid = process.pid; + process_map.insert(pid, process.clone()); + + if let Some(parent_pid) = process.parent_pid { + children_map.entry(parent_pid).or_default().push(pid); + } else { + // Process without parent is a root + roots.push(pid); + } + } + + // Second pass: build trees + fn build_tree_recursive( + pid: u32, + process_map: &HashMap, + children_map: &HashMap>, + depth: usize, + ) -> Option { + let process_info = process_map.get(&pid)?.clone(); + let mut node = ProcessTreeNode::new(process_info, depth); + + if let Some(child_pids) = children_map.get(&pid) { + for &child_pid in child_pids { + if let Some(child_node) = + build_tree_recursive(child_pid, process_map, children_map, depth + 1) + { + node.add_child(child_node); + } + } + } + + Some(node) + } + + // Build root trees + let mut result = Vec::new(); + for root_pid in roots { + if let Some(tree) = build_tree_recursive(root_pid, &process_map, &children_map, 0) { + result.push(tree); + } + } + + // Handle orphaned processes (parent not in our process list) + for (pid, process_info) in &process_map { + if let Some(parent_pid) = process_info.parent_pid { + if !process_map.contains_key(&parent_pid) { + // Parent not found, treat as root + if let Some(tree) = build_tree_recursive(*pid, &process_map, &children_map, 0) { + result.push(tree); + } + } + } + } + + result +} + +pub fn aggregate_bandwidth_by_tree( + process_trees: &[ProcessTreeNode], + processes_map: &HashMap, +) -> HashMap { + use crate::display::{Bandwidth, NetworkData}; + + let mut aggregated = HashMap::new(); + + fn aggregate_recursive( + node: &ProcessTreeNode, + processes_map: &HashMap, + aggregated: &mut HashMap, + ) -> NetworkData { + // Get this process's own bandwidth data + let mut total_data = processes_map + .get(&node.process_info) + .cloned() + .unwrap_or_default(); + + // Aggregate data from all children + for child in &node.children { + let child_data = aggregate_recursive(child, processes_map, aggregated); + total_data.combine_bandwidth(&child_data); + } + + // Store the aggregated data for this process + aggregated.insert(node.process_info.clone(), total_data.clone()); + total_data + } + + for tree in process_trees { + aggregate_recursive(tree, processes_map, &mut aggregated); + } + + aggregated } pub struct TerminalEvents; @@ -241,3 +422,160 @@ fn eperm_message() -> &'static str { fn eperm_message() -> &'static str { "Insufficient permissions to listen on network interface(s). Try running with administrator rights." } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_process_info_new() { + let proc_info = ProcessInfo::new("test_process", 1234); + assert_eq!(proc_info.name, "test_process"); + assert_eq!(proc_info.pid, 1234); + assert_eq!(proc_info.parent_pid, None); + } + + #[test] + fn test_process_info_with_parent() { + let proc_info = ProcessInfo::with_parent("test_process", 1234, Some(5678)); + assert_eq!(proc_info.name, "test_process"); + assert_eq!(proc_info.pid, 1234); + assert_eq!(proc_info.parent_pid, Some(5678)); + } + + #[test] + fn test_build_process_trees_simple() { + let processes = vec![ + ProcessInfo::with_parent("parent", 1, None), + ProcessInfo::with_parent("child", 2, Some(1)), + ]; + + let trees = build_process_trees(processes); + assert_eq!(trees.len(), 1); + + let root = &trees[0]; + assert_eq!(root.process_info.name, "parent"); + assert_eq!(root.process_info.pid, 1); + assert_eq!(root.children.len(), 1); + + let child = &root.children[0]; + assert_eq!(child.process_info.name, "child"); + assert_eq!(child.process_info.pid, 2); + assert_eq!(child.children.len(), 0); + } + + #[test] + fn test_build_process_trees_multiple_roots() { + let processes = vec![ + ProcessInfo::with_parent("root1", 1, None), + ProcessInfo::with_parent("root2", 2, None), + ProcessInfo::with_parent("child1", 3, Some(1)), + ]; + + let trees = build_process_trees(processes); + assert_eq!(trees.len(), 2); + + // Find the root with children + let tree_with_child = trees.iter().find(|t| !t.children.is_empty()).unwrap(); + assert_eq!(tree_with_child.process_info.name, "root1"); + assert_eq!(tree_with_child.children.len(), 1); + assert_eq!(tree_with_child.children[0].process_info.name, "child1"); + + // Find the root without children + let tree_without_child = trees.iter().find(|t| t.children.is_empty()).unwrap(); + assert_eq!(tree_without_child.process_info.name, "root2"); + } + + #[test] + fn test_build_process_trees_orphaned_processes() { + let processes = vec![ + ProcessInfo::with_parent("orphan", 1, Some(999)), // Parent 999 doesn't exist + ProcessInfo::with_parent("child", 2, Some(1)), + ]; + + let trees = build_process_trees(processes); + assert_eq!(trees.len(), 1); + + let root = &trees[0]; + assert_eq!(root.process_info.name, "orphan"); + assert_eq!(root.children.len(), 1); + assert_eq!(root.children[0].process_info.name, "child"); + } + + #[test] + fn test_process_tree_iterator() { + let processes = vec![ + ProcessInfo::with_parent("root", 1, None), + ProcessInfo::with_parent("child1", 2, Some(1)), + ProcessInfo::with_parent("child2", 3, Some(1)), + ProcessInfo::with_parent("grandchild", 4, Some(2)), + ]; + + let trees = build_process_trees(processes); + assert_eq!(trees.len(), 1); + + let root = &trees[0]; + let items: Vec<_> = root.iter_depth_first().collect(); + + // Should visit root, then child1, then grandchild, then child2 + assert_eq!(items.len(), 4); + assert_eq!(items[0].0.name, "root"); + assert_eq!(items[0].1, 0); // depth 0 + assert_eq!(items[1].0.name, "child1"); + assert_eq!(items[1].1, 1); // depth 1 + assert_eq!(items[2].0.name, "grandchild"); + assert_eq!(items[2].1, 2); // depth 2 + assert_eq!(items[3].0.name, "child2"); + assert_eq!(items[3].1, 1); // depth 1 + } + + #[test] + fn test_aggregate_bandwidth_by_tree() { + use crate::display::NetworkData; + use std::collections::HashMap; + + // Create test processes + let processes = vec![ + ProcessInfo::with_parent("parent", 1, None), + ProcessInfo::with_parent("child", 2, Some(1)), + ]; + + // Create bandwidth data + let mut bandwidth_map = HashMap::new(); + bandwidth_map.insert( + ProcessInfo::with_parent("parent", 1, None), + NetworkData { + total_bytes_downloaded: 100, + total_bytes_uploaded: 50, + connection_count: 1, + }, + ); + bandwidth_map.insert( + ProcessInfo::with_parent("child", 2, Some(1)), + NetworkData { + total_bytes_downloaded: 200, + total_bytes_uploaded: 100, + connection_count: 2, + }, + ); + + let trees = build_process_trees(processes); + let aggregated = aggregate_bandwidth_by_tree(&trees, &bandwidth_map); + + // Parent should have its own bandwidth + child's bandwidth + let parent_data = aggregated + .get(&ProcessInfo::with_parent("parent", 1, None)) + .unwrap(); + assert_eq!(parent_data.total_bytes_downloaded, 300); // 100 + 200 + assert_eq!(parent_data.total_bytes_uploaded, 150); // 50 + 100 + assert_eq!(parent_data.connection_count, 2); // child's count (combined_bandwidth doesn't add connection_count for parent) + + // Child should have only its own bandwidth + let child_data = aggregated + .get(&ProcessInfo::with_parent("child", 2, Some(1))) + .unwrap(); + assert_eq!(child_data.total_bytes_downloaded, 200); + assert_eq!(child_data.total_bytes_uploaded, 100); + assert_eq!(child_data.connection_count, 2); + } +} diff --git a/src/os/windows.rs b/src/os/windows.rs index 65a765d6..f373208d 100644 --- a/src/os/windows.rs +++ b/src/os/windows.rs @@ -25,7 +25,14 @@ pub(crate) fn get_open_sockets() -> OpenSockets { .associated_pids .into_iter() .find_map(|pid| sysinfo.process(Pid::from_u32(pid))) - .map(|p| ProcessInfo::new(&p.name().to_string_lossy(), p.pid().as_u32())) + .map(|p| { + let parent_pid = p.parent().map(|ppid| ppid.as_u32()); + ProcessInfo::with_parent( + &p.name().to_string_lossy(), + p.pid().as_u32(), + parent_pid, + ) + }) .unwrap_or_default(); match si.protocol_socket_info {