Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions probing/cli/src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,26 @@ pub enum Commands {
#[command(subcommand = false, hide = true)]
Store(StoreCommand),
}

impl Commands {
/// Determines whether this command should have a timeout applied.
/// Long-running commands like Launch and External should not time out.
pub fn is_timed_command(&self) -> bool {
match self {
// Long-running commands - no timeout
Commands::Repl => false,
Commands::Launch { .. } => false,
Commands::External(_) => false,
// Short-running commands - apply timeout
Commands::List { .. } => true,
Commands::Backtrace { .. } => true,
Commands::Rdma { .. } => true,
Commands::Eval { .. } => true,
Commands::Query { .. } => true,
Commands::Store(_) => true,
Commands::Config { .. } => true,
#[cfg(target_os = "linux")]
Commands::Inject(_) => true,
}
}
}
3 changes: 3 additions & 0 deletions probing/cli/src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@
#[command(subcommand)]
command: Option<Commands>,
}

Check warning on line 57 in probing/cli/src/cli/mod.rs

View workflow job for this annotation

GitHub Actions / Build Package and Run Python Tests (Linux)

Diff in /home/runner/work/probing/probing/probing/cli/src/cli/mod.rs
impl Cli {
pub fn should_timeout(&self) -> bool {
self.command.as_ref().map_or(true, |cmd| cmd.is_timed_command())
}
pub async fn run(&mut self) -> Result<()> {
// Handle external commands first to avoid target requirement
if let Some(Commands::External(args)) = &self.command {
Expand Down
15 changes: 13 additions & 2 deletions probing/cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,23 @@
use anyhow::Result;
use clap::Parser;
use env_logger::Env;
use std::time::Duration;
use tokio::time::timeout;

const ENV_PROBING_LOGLEVEL: &str = "PROBING_LOGLEVEL";

/// Main entry point for the CLI, can be called from Python or as a binary
#[tokio::main]
pub async fn cli_main(args: Vec<String>) -> Result<()> {

Check warning on line 16 in probing/cli/src/lib.rs

View workflow job for this annotation

GitHub Actions / Build Package and Run Python Tests (Linux)

Diff in /home/runner/work/probing/probing/probing/cli/src/lib.rs
let _ = env_logger::try_init_from_env(Env::new().filter(ENV_PROBING_LOGLEVEL));
cli::Cli::parse_from(args).run().await

let mut cli = cli::Cli::parse_from(args);

if cli.should_timeout() {
match timeout(Duration::from_secs(10), cli.run()).await {
Ok(result) => result,

Check warning on line 23 in probing/cli/src/lib.rs

View workflow job for this annotation

GitHub Actions / Build Package and Run Python Tests (Linux)

Diff in /home/runner/work/probing/probing/probing/cli/src/lib.rs
Err(_) => Err(anyhow::anyhow!("Cli Command Timeout reached")),
}
} else {
cli.run().await
}
}
5 changes: 3 additions & 2 deletions probing/cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use anyhow::Result;
use probing_cli::cli_main;

fn main() -> Result<()> {
#[tokio::main]
async fn main() -> Result<()> {
let args: Vec<String> = std::env::args().collect();
// cli_main already uses #[tokio::main], so it handles async execution internally
cli_main(args)
cli_main(args).await
}
7 changes: 6 additions & 1 deletion probing/extensions/python/src/features/python_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ pub fn query_json(_py: Python, sql: String) -> PyResult<String> {

#[pyfunction]
pub fn cli_main(_py: Python, args: Vec<String>) -> PyResult<()> {
if let Err(e) = cli_main_impl(args) {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;

if let Err(e) = runtime.block_on(cli_main_impl(args)) {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
e.to_string(),
));
Expand Down
Loading