diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 17ff717b..baddae33 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "srpc_client" -version = "0.1.1" +version = "0.2.0" edition = "2021" [lib] @@ -8,29 +8,48 @@ name = "srpc_client" crate-type = ["cdylib", "rlib"] [dependencies] -tokio = { version = "1.0", features = ["full"] } +async-trait = "0.1" +bytes = "1" +futures = "0.3" +tokio = { version = "1", features = ["full"] } openssl = "0.10" -serde_json = "1.0" +serde = {version = "1", features = ["derive"]} +serde_json = "1" tokio-openssl = "0.6" +tokio-util = { version = "0.7", features = ["codec"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } [dependencies.pyo3] -version = "0.18" +version = "0.23" features = ["extension-module"] optional = true -[dependencies.pyo3-asyncio] -version = "0.18" -features = ["tokio-runtime"] +[dependencies.pyo3-async-runtimes] +version = "0.23" +features = ["attributes", "tokio-runtime"] optional = true [features] default = [] -python = ["pyo3", "pyo3-asyncio"] +python = ["pyo3", "pyo3-async-runtimes"] [[example]] name = "rust_client_example" path = "examples/rust_client_example.rs" required-features = [] +[[example]] +name = "rust_client_example2" +path = "examples/rust_client_example2.rs" +required-features = [] + +[[example]] +name = "rust_client_example3" +path = "examples/rust_client_example3.rs" +required-features = [] + [dev-dependencies] -tokio = { version = "1.0", features = ["full", "macros"] } +rstest = "0.23.0" +test-log = { version = "0.2.16", features = ["trace", "color"] } +tokio = { version = "1", features = ["full", "macros", "test-util"] } diff --git a/rust/lib/srpc/client/examples/python_client_example.py b/rust/lib/srpc/client/examples/python_client_example.py index 11bc6b1c..52f1fa52 100644 --- a/rust/lib/srpc/client/examples/python_client_example.py +++ b/rust/lib/srpc/client/examples/python_client_example.py @@ -2,19 +2,33 @@ This example demonstrates how to use the srpc_client Python bindings. To run this example: -1. Build the Rust library: maturin build --features python -2. Install the wheel: pip install target/wheels/srpc_client-*.whl -3. Run this script: python examples/python_client_example.py +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_1_SRPC_SERVER_HOST= \ + EXAMPLE_1_SRPC_SERVER_PORT= \ + EXAMPLE_1_SRPC_SERVER_ENPOINT= \ + EXAMPLE_1_SRPC_SERVER_CERT= \ + EXAMPLE_1_SRPC_SERVER_KEY= \ + python examples/python_client_example.py """ import asyncio import json -from srpc_client import SrpcClient +import os +from srpc_client import SrpcClientConfig + async def main(): - client = SrpcClient("", 6976, "/_SRPC_/TLS/JSON", "", "") - - await client.connect() + client = SrpcClientConfig( + os.environ["EXAMPLE_1_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_1_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_1_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_1_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_1_SRPC_SERVER_KEY"], + ) + + client = await client.connect() print("Connected to server") message = "Hypervisor.StartVm\n" @@ -25,9 +39,7 @@ async def main(): for response in responses: print(f"Received response: {response}") - json_payload = { - "IpAddress": "" - } + json_payload = {"IpAddress": ""} json_string = json.dumps(json_payload) await client.send_json(json_string) print(f"Sent JSON payload: {json_payload}") @@ -36,5 +48,6 @@ async def main(): for json_response in json_responses: print(f"Received JSON response: {json.loads(json_response)}") + if __name__ == "__main__": asyncio.run(main()) diff --git a/rust/lib/srpc/client/examples/python_client_example2.py b/rust/lib/srpc/client/examples/python_client_example2.py new file mode 100644 index 00000000..ea271912 --- /dev/null +++ b/rust/lib/srpc/client/examples/python_client_example2.py @@ -0,0 +1,57 @@ +""" +This example demonstrates how to use the srpc_client Python bindings. + +To run this example: +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_2_SRPC_SERVER_HOST= \ + EXAMPLE_2_SRPC_SERVER_PORT= \ + EXAMPLE_2_SRPC_SERVER_ENPOINT= \ + EXAMPLE_2_SRPC_SERVER_CERT= \ + EXAMPLE_2_SRPC_SERVER_KEY= \ + python examples/python_client_example2.py +""" + +import asyncio +import json +import os +from srpc_client import SrpcClientConfig + + +async def main(): + print("Starting client..") + + # Create a new ClientConfig instance + client = SrpcClientConfig( + os.environ["EXAMPLE_2_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_2_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_2_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_2_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_2_SRPC_SERVER_KEY"], + ) + + # Connect to the server + client = await client.connect() + print("Connected to server") + + # Send a message + message = "Hypervisor.GetUpdates\n" + print(f"Sending message: {message}") + await client.send_message(message) + print(f"Sent message: {message}") + + # Receive an empty response + print("Waiting for empty string response...") + responses = await client.receive_message(expect_empty=True, should_continue=False) + async for response in responses: + print(f"Received response: {response}") + + # Receive responses + responses = await client.receive_json_cb(should_continue=lambda _: True) + async for response in responses: + print(f"Received response: {json.loads(response)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rust/lib/srpc/client/examples/python_client_example3.py b/rust/lib/srpc/client/examples/python_client_example3.py new file mode 100644 index 00000000..0c80bf5a --- /dev/null +++ b/rust/lib/srpc/client/examples/python_client_example3.py @@ -0,0 +1,95 @@ +""" +This example demonstrates how to use the srpc_client Python bindings. + +To run this example: +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_3_SRPC_SERVER_HOST= \ + EXAMPLE_3_SRPC_SERVER_PORT= \ + EXAMPLE_3_SRPC_SERVER_ENPOINT= \ + EXAMPLE_3_SRPC_SERVER_CERT= \ + EXAMPLE_3_SRPC_SERVER_KEY= \ + python examples/python_client_example3.py +""" + +import asyncio +import json +import os +from srpc_client import SrpcClientConfig + + +async def main(): + print("Starting client..") + + # Create a new ClientConfig instance + client = SrpcClientConfig( + os.environ["EXAMPLE_3_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_3_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_3_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_3_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_3_SRPC_SERVER_KEY"], + ) + + # Connect to the server + client = await client.connect() + print("Connected to server") + + message = "Hypervisor.ListVMs\n" + + # Send a message + print(f"Sending message: {message}") + await client.send_message(message) + print(f"Sent message: {message}") + + # Receive an empty response + print("Waiting for empty string response...") + responses = await client.receive_message(expect_empty=True, should_continue=False) + async for response in responses: + print(f"Received response: {response}") + + # Send a JSON message + payload = json.dumps( + { + "IgnoreStateMask": 0, + "OwnerGroups": [], + "OwnerUsers": [], + "Sort": True, + "VmTagsToMatch": {}, + } + ) + print(f"Sending payload: {payload}") + await client.send_json(payload) + print(f"Sent payload: {payload}") + + # Receive an empty response + print("Waiting for empty string response for payload...") + responses = await client.receive_message(expect_empty=True, should_continue=False) + async for response in responses: + print(f"Received response: {response}") + + # Receive responses + print("Waiting for response...") + responses = await client.receive_json_cb(should_continue=lambda _: False) + async for response in responses: + print(f"Received response: {json.loads(response)}") + + # Use RequestReply + print(f"Sending request_reply: {message}") + res = await client.request_reply( + message, + json.dumps( + { + "IgnoreStateMask": 0, + "OwnerGroups": [], + "OwnerUsers": [], + "Sort": True, + "VmTagsToMatch": {}, + } + ), + ) + print(f"Sent request_reply: {message}, got reply: {res}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rust/lib/srpc/client/examples/python_client_example4.py b/rust/lib/srpc/client/examples/python_client_example4.py new file mode 100644 index 00000000..a516c69a --- /dev/null +++ b/rust/lib/srpc/client/examples/python_client_example4.py @@ -0,0 +1,54 @@ +""" +This example demonstrates how to use the srpc_client Python bindings. + +To run this example: +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_4_SRPC_SERVER_HOST= \ + EXAMPLE_4_SRPC_SERVER_PORT= \ + EXAMPLE_4_SRPC_SERVER_ENPOINT= \ + EXAMPLE_4_SRPC_SERVER_CERT= \ + EXAMPLE_4_SRPC_SERVER_KEY= \ + python examples/python_client_example4.py +""" + +import asyncio +import json +import os +from srpc_client import SrpcClientConfig + + +async def main(): + print("Starting client..") + + # Create a new ClientConfig instance + client = SrpcClientConfig( + os.environ["EXAMPLE_4_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_4_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_4_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_4_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_4_SRPC_SERVER_KEY"], + ) + + # Connect to the server + client = await client.connect() + print("Connected to server") + + # Send a message + message = "Hypervisor.GetUpdates\n" + print(f"Calling server with message: {message}") + conn = await client.call(message) + response = await conn.decode() + print(f"Received response: {json.loads(response)}") + await conn.close() + + print(f"Calling server with message again: {message}") + conn2 = await client.call(message) + response = await conn2.decode() + print(f"Received response: {json.loads(response)}") + await conn2.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index 2e723dbb..f0499bc2 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -1,34 +1,57 @@ -use srpc_client::Client; -use tokio; +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_1_SRPC_SERVER_HOST= \ + EXAMPLE_1_SRPC_SERVER_PORT= \ + EXAMPLE_1_SRPC_SERVER_ENPOINT= \ + EXAMPLE_1_SRPC_SERVER_CERT= \ + EXAMPLE_1_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example +**/ use serde_json::json; +use srpc_client::{ClientConfig, ReceiveOptions}; +use tracing::{error, info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + // Create a new Client instance - let client = Client::new( - "", - 6976, - "/_SRPC_/TLS/JSON", - "", - "" + let client = ClientConfig::new( + &std::env::var("EXAMPLE_1_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_1_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_1_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_1_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_1_SRPC_SERVER_KEY")?, ); // Connect to the server - client.connect().await?; - println!("Connected to server"); + let client = client.connect().await?; + info!("Connected to server"); // Send a message let message = "Hypervisor.ProbeVmPort\n"; - println!("Sending message: {:?}", message); + info!("Sending message: {:?}", message); client.send_message(message).await?; // Receive an empty response - println!("Waiting for empty string response..."); - let mut rx = client.receive_message(true, |_| false).await?; + info!("Waiting for empty string response..."); + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; while let Some(result) = rx.recv().await { match result { - Ok(response) => println!("Received response: {:?}", response), - Err(e) => eprintln!("Error receiving message: {:?}", e), + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), } } @@ -38,16 +61,18 @@ async fn main() -> Result<(), Box> { "PortNumber": 22 }); - println!("Sending JSON payload: {:?}", json_payload); + info!("Sending JSON payload: {:?}", json_payload); client.send_json(&json_payload).await?; // Receive and parse JSON response - println!("Waiting for JSON response..."); - let mut rx = client.receive_json(|_| false).await?; + info!("Waiting for JSON response..."); + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await?; while let Some(result) = rx.recv().await { match result { - Ok(json_response) => println!("Received JSON response: {:?}", json_response), - Err(e) => eprintln!("Error receiving JSON: {:?}", e), + Ok(json_response) => info!("Received JSON response: {:?}", json_response), + Err(e) => error!("Error receiving JSON: {:?}", e), } } diff --git a/rust/lib/srpc/client/examples/rust_client_example2.rs b/rust/lib/srpc/client/examples/rust_client_example2.rs new file mode 100644 index 00000000..cfc6e788 --- /dev/null +++ b/rust/lib/srpc/client/examples/rust_client_example2.rs @@ -0,0 +1,70 @@ +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_2_SRPC_SERVER_HOST= \ + EXAMPLE_2_SRPC_SERVER_PORT= \ + EXAMPLE_2_SRPC_SERVER_ENPOINT= \ + EXAMPLE_2_SRPC_SERVER_CERT= \ + EXAMPLE_2_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example2 +**/ +use srpc_client::{ClientConfig, ReceiveOptions}; +use tracing::{error, info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + + // Create a new ClientConfig instance + let config = ClientConfig::new( + &std::env::var("EXAMPLE_2_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_2_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_2_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_2_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_2_SRPC_SERVER_KEY")?, + ); + + // Connect to the server + let client = config.connect().await?; + info!("Connected to server"); + + // Send a message + let message = "Hypervisor.GetUpdates\n"; + info!("Sending message: {:?}", message); + client.send_message(message).await?; + info!("Sent message: {:?}", message); + + // Receive an empty response + info!("Waiting for empty string response..."); + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + // Receive responses + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + Ok(()) +} diff --git a/rust/lib/srpc/client/examples/rust_client_example3.rs b/rust/lib/srpc/client/examples/rust_client_example3.rs new file mode 100644 index 00000000..afbeba9e --- /dev/null +++ b/rust/lib/srpc/client/examples/rust_client_example3.rs @@ -0,0 +1,129 @@ +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_3_SRPC_SERVER_HOST= \ + EXAMPLE_3_SRPC_SERVER_PORT= \ + EXAMPLE_3_SRPC_SERVER_ENPOINT= \ + EXAMPLE_3_SRPC_SERVER_CERT= \ + EXAMPLE_3_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example3 +**/ +use std::{collections::HashMap, error::Error}; + +use srpc_client::{ClientConfig, CustomError, ReceiveOptions, SimpleValue}; +use tracing::{error, info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + + // Create a new ClientConfig instance + let config = ClientConfig::new( + &std::env::var("EXAMPLE_3_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_3_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_3_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_3_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_3_SRPC_SERVER_KEY")?, + ); + + // Connect to the server + let client = config.connect().await?; + info!("Connected to server"); + + let message = "Hypervisor.ListVMs\n"; + + // Send a message + info!("Sending message: {:?}", message); + client.send_message(message).await?; + info!("Sent message: {:?}", message); + + // Receive an empty response + info!("Waiting for empty string response..."); + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + #[derive(Debug, serde::Serialize)] + struct ListVMsRequest { + ignore_state_mask: u32, + owner_groups: Vec, + owner_users: Vec, + sort: bool, + vm_tags_to_match: HashMap, + } + + #[derive(Debug, serde::Deserialize)] + struct ListVMsResponse { + ip_addresses: Vec, + } + + let request = ListVMsRequest { + ignore_state_mask: 0, + owner_groups: vec![], + owner_users: vec![], + sort: false, + vm_tags_to_match: HashMap::new(), + }; + + // Send a JSON message + info!("Sending payload: {:?}", request); + client.send_json(&serde_json::to_value(&request)?).await?; + info!("Sent payload: {:?}", request); + + // Receive an empty response + info!("Waiting for empty string response for payload..."); + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + // Receive responses + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result + .and_then(|response| { + serde_json::from_value::(response) + .map_err(|e| Box::new(CustomError(e.to_string())) as Box) + }) + .map_err(|e| Box::new(CustomError(e.to_string())) as Box) + { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + info!("Sending request_reply: {}", message); + let res = client + .request_reply::(message, serde_json::to_value(&request)?) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + info!( + "Sent request_reply: {}, got reply: {:?}", + message, + serde_json::to_string(&res)? + ); + + Ok(()) +} diff --git a/rust/lib/srpc/client/examples/rust_client_example4.rs b/rust/lib/srpc/client/examples/rust_client_example4.rs new file mode 100644 index 00000000..cc50a463 --- /dev/null +++ b/rust/lib/srpc/client/examples/rust_client_example4.rs @@ -0,0 +1,71 @@ +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_4_SRPC_SERVER_HOST= \ + EXAMPLE_4_SRPC_SERVER_PORT= \ + EXAMPLE_4_SRPC_SERVER_ENPOINT= \ + EXAMPLE_4_SRPC_SERVER_CERT= \ + EXAMPLE_4_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example4 +**/ +use std::{error::Error, sync::Arc}; + +use srpc_client::{ClientConfig, CustomError}; +use tokio::sync::Mutex; +use tracing::{info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + + // Create a new ClientConfig instance + let config = ClientConfig::new( + &std::env::var("EXAMPLE_4_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_4_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_4_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_4_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_4_SRPC_SERVER_KEY")?, + ); + + // Connect to the server + let client = config.connect().await?; + info!("Connected to server"); + + let message = "Hypervisor.GetUpdates\n"; + + let safe_client = Arc::new(Mutex::new(client)); + let guard = safe_client.lock_owned().await; + info!("Calling server with message: {:?}", message); + let mut conn = srpc_client::call(guard, message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let val = conn + .decode() + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + info!("Received response: {:?}", val); + + let guard = conn.close(); + + info!("Calling server with message again: {:?}", message); + let mut conn2 = srpc_client::call(guard, message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let val = conn2 + .decode() + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + info!("Received response: {:?}", val); + let _guard = conn2.close(); + + Ok(()) +} diff --git a/rust/lib/srpc/client/src/chunk_limiter.rs b/rust/lib/srpc/client/src/chunk_limiter.rs new file mode 100644 index 00000000..d03a6527 --- /dev/null +++ b/rust/lib/srpc/client/src/chunk_limiter.rs @@ -0,0 +1,39 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; +use tracing::trace; + +pub struct ChunkLimiter { + inner: R, + max_chunk_size: usize, +} + +impl ChunkLimiter { + pub fn new(inner: R, max_chunk_size: usize) -> Self { + Self { + inner, + max_chunk_size, + } + } +} + +impl AsyncRead for ChunkLimiter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let limit = self.max_chunk_size.min(buf.remaining()); + + let available = buf.initialize_unfilled_to(limit); + let mut limited_buf = ReadBuf::new(available); + + let poll_result = Pin::new(&mut self.inner).poll_read(cx, &mut limited_buf); + + let filled_len = limited_buf.filled().len(); + trace!("Read {} bytes", filled_len); + buf.advance(filled_len); + + poll_result + } +} diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 46af1760..36b7f1a7 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -1,18 +1,28 @@ -use tokio::net::TcpStream; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use async_trait::async_trait; +use chunk_limiter::ChunkLimiter; +use futures::StreamExt; +use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode}; +use serde_json::Value; +use std::borrow::BorrowMut; use std::error::Error; use std::fmt; -use openssl::ssl::{SslMethod, SslConnector, SslVerifyMode, Ssl}; -use serde_json::Value; -use tokio_openssl::SslStream; -use tokio::time::{timeout, Duration}; -use std::sync::Arc; -use tokio::sync::{Mutex, mpsc}; use std::pin::Pin; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::net::TcpStream; +use tokio::sync::{mpsc, Mutex, OwnedMutexGuard}; +use tokio::time::{timeout, Duration}; +use tokio_openssl::SslStream; +use tokio_util::codec::{FramedRead, LinesCodec}; +use tracing::debug; + +mod chunk_limiter; +#[cfg(test)] +mod tests; // Custom error type #[derive(Debug)] -struct CustomError(String); +pub struct CustomError(pub String); impl fmt::Display for CustomError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -22,45 +32,213 @@ impl fmt::Display for CustomError { impl Error for CustomError {} -pub struct Client { +#[derive(Clone)] +pub struct ClientConfig { host: String, port: u16, path: String, cert: String, key: String, - stream: Arc>>>, } -impl Client { +#[cfg_attr(feature = "python", derive(FromPyObject))] +pub struct ReceiveOptions { + channel_buffer_size: usize, + max_chunk_size: usize, + read_next_line_duration: Duration, + should_continue_on_timeout: bool, +} + +impl ReceiveOptions { + pub fn new( + channel_buffer_size: usize, + max_chunk_size: usize, + read_next_line_duration: Duration, + should_continue_on_timeout: bool, + ) -> Self { + ReceiveOptions { + channel_buffer_size, + max_chunk_size, + read_next_line_duration, + should_continue_on_timeout, + } + } +} + +impl Default for ReceiveOptions { + fn default() -> Self { + ReceiveOptions { + channel_buffer_size: 100, + max_chunk_size: 16384, + read_next_line_duration: Duration::from_secs(10), + should_continue_on_timeout: true, + } + } +} + +pub struct ConnectedClient +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub connection_params: ClientConfig, + stream: Arc>, +} + +#[async_trait] +pub trait RequestReply { + type Request; + type Reply; + + async fn request_reply( + client: &ConnectedClient, + payload: Self::Request, + ) -> Result> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static; +} + +pub struct SimpleValue; + +#[async_trait] +impl RequestReply for SimpleValue { + type Request = serde_json::Value; + type Reply = serde_json::Value; + + async fn request_reply( + client: &ConnectedClient, + payload: Self::Request, + ) -> Result> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + client.send_json_and_check(&payload).await?; + + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let json_value = rx.recv().await.ok_or_else(|| { + Box::new(CustomError("Expected JSON value".to_string())) as Box + })??; + Ok(json_value) + } +} + +pub struct StreamValue; + +#[async_trait] +impl RequestReply for StreamValue { + type Request = serde_json::Value; + type Reply = mpsc::Receiver>>; + + async fn request_reply( + client: &ConnectedClient, + payload: Self::Request, + ) -> Result> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + client.send_json_and_check(&payload).await?; + + let rx = client + .receive_json(|_| true, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + Ok(rx) + } +} + +pub struct Conn +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + guard: Option>>, + rx: mpsc::Receiver>>, +} + +impl Conn +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub async fn new( + guard: tokio::sync::OwnedMutexGuard>, + ) -> Result> { + let rx = guard + .receive_json(|_| true, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + + Ok(Conn { + guard: Some(guard), + rx, + }) + } + + pub async fn encode(&self, message: Value) -> Result<(), Box> { + self.guard + .as_ref() + .unwrap() + .send_json(&message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box) + } + + pub async fn decode(&mut self) -> Result> { + let rx = self.rx.borrow_mut(); + let json_value = rx.recv().await.ok_or_else(|| { + Box::new(CustomError("Expected JSON value".to_string())) as Box + })??; + Ok(json_value) + } + + pub fn close(&mut self) -> OwnedMutexGuard> { + self.guard.take().unwrap() + } +} + +pub async fn call( + client: OwnedMutexGuard>, + method: &str, +) -> Result, Box> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + client.send_message_and_check(method).await?; + Conn::new(client).await +} + +impl ClientConfig { pub fn new(host: &str, port: u16, path: &str, cert: &str, key: &str) -> Self { - Client { + ClientConfig { host: host.to_string(), port, path: path.to_string(), cert: cert.to_string(), key: key.to_string(), - stream: Arc::new(Mutex::new(None)), } } - pub async fn connect(&self) -> Result<(), Box> { - // println!("Attempting to connect to {}:{}...", self.host, self.port); - + pub async fn connect(self) -> Result>, Box> { + debug!("Attempting to connect to {}:{}...", self.host, self.port); + let connect_timeout = Duration::from_secs(10); - let tcp_stream = match timeout(connect_timeout, - TcpStream::connect(format!("{}:{}", self.host, self.port)) - ).await { + let tcp_stream = match timeout( + connect_timeout, + TcpStream::connect(format!("{}:{}", self.host, self.port)), + ) + .await + { Ok(Ok(stream)) => stream, Ok(Err(e)) => return Err(format!("Failed to connect: {}", e).into()), Err(_) => return Err("Connection attempt timed out".into()), }; - // println!("TCP connection established"); - - // println!("Performing HTTP CONNECT..."); + debug!("TCP connection established"); + + debug!("Performing HTTP CONNECT..."); self.do_http_connect(&tcp_stream).await?; - // println!("HTTP CONNECT successful"); - - // println!("Starting TLS handshake..."); + debug!("HTTP CONNECT successful"); + + debug!("Starting TLS handshake..."); let mut connector = SslConnector::builder(SslMethod::tls())?; connector.set_verify(SslVerifyMode::NONE); @@ -68,31 +246,29 @@ impl Client { connector.set_certificate_file(&self.cert, openssl::ssl::SslFiletype::PEM)?; connector.set_private_key_file(&self.key, openssl::ssl::SslFiletype::PEM)?; } - + let ssl = Ssl::new(connector.build().context())?; let mut stream = SslStream::new(ssl, tcp_stream)?; - - // println!("Performing TLS handshake..."); + + debug!("Performing TLS handshake..."); Pin::new(&mut stream).connect().await?; - // println!("TLS handshake completed"); - - let mut lock = self.stream.lock().await; - *lock = Some(stream); - // println!("Connection fully established"); - - Ok(()) + debug!("TLS handshake completed"); + + debug!("Connection fully established"); + + Ok(ConnectedClient::new(self, stream)) } async fn do_http_connect(&self, stream: &TcpStream) -> Result<(), Box> { let connect_request = format!("CONNECT {} HTTP/1.0\r\n\r\n", self.path); - // println!("Sending HTTP CONNECT request: {:?}", connect_request); + debug!("Sending HTTP CONNECT request: {:?}", connect_request); stream.try_write(connect_request.as_bytes())?; - // println!("HTTP CONNECT request sent"); - + debug!("HTTP CONNECT request sent"); + let read_timeout = Duration::from_secs(10); let start_time = std::time::Instant::now(); let mut buffer = Vec::new(); - + while start_time.elapsed() < read_timeout { match stream.try_read_buf(&mut buffer) { Ok(0) => { @@ -112,79 +288,163 @@ impl Client { Err(e) => return Err(format!("Error reading HTTP CONNECT response: {}", e).into()), } } - + if buffer.is_empty() { return Err("Timeout while waiting for HTTP CONNECT response".into()); } - + let response = String::from_utf8_lossy(&buffer); - // println!("Received HTTP CONNECT response: {:?}", response); + debug!("Received HTTP CONNECT response: {:?}", response); if response.starts_with("HTTP/1.0 200") || response.starts_with("HTTP/1.1 200") { - // println!("HTTP CONNECT completed successfully"); + debug!("HTTP CONNECT completed successfully"); Ok(()) } else { Err(format!("Unexpected HTTP response: {}", response).into()) } } +} + +impl ConnectedClient +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub fn new(connection_params: ClientConfig, stream: T) -> Self { + ConnectedClient { + connection_params, + stream: Arc::new(Mutex::new(stream)), + } + } + + pub async fn request_reply( + &self, + method: &str, + payload: R::Request, + ) -> Result> + where + R: RequestReply, + { + self.send_message(method) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let mut rx = self + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + if rx + .recv() + .await + .ok_or_else(|| { + Box::new(CustomError("Expected response".to_string())) as Box + })?? + .is_empty() + { + } else { + return Err(Box::new(CustomError("Expected empty line".to_string()))); + } + + R::request_reply(self, payload).await + } + + // pub async fn call( + // self, + // method: &str, + // ) -> Result, Box> + // { + // self.send_message_and_check(method).await?; + // Conn::new(self.stream.lock_owned().await).await + // } pub async fn send_message(&self, message: &str) -> Result<(), Box> { - let mut lock = self.stream.lock().await; - if let Some(stream) = lock.as_mut() { - let mut pinned = Pin::new(stream); - pinned.as_mut().write_all(message.as_bytes()).await?; - pinned.as_mut().flush().await?; - Ok(()) + let stream = self.stream.lock().await; + let mut pinned = Pin::new(stream); + pinned.as_mut().write_all(message.as_bytes()).await?; + pinned.as_mut().flush().await?; + Ok(()) + } + + pub async fn send_message_and_check(&self, message: &str) -> Result<(), Box> { + self.send_message(message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let mut rx = self + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + if rx + .recv() + .await + .ok_or_else(|| { + Box::new(CustomError("Expected response".to_string())) as Box + })?? + .is_empty() + { } else { - Err("Not connected".into()) + return Err(Box::new(CustomError("Expected empty line".to_string()))); } + + Ok(()) } - pub async fn receive_message(&self, expect_empty: bool, mut should_continue: F) -> Result>>, Box> + pub async fn receive_message( + &self, + expect_empty: bool, + mut should_continue: F, + opts: &ReceiveOptions, + ) -> Result>>, Box> where F: FnMut(&str) -> bool + Send + 'static, { - let stream_clone = self.stream.clone(); - let (tx, rx) = mpsc::channel(100); + let stream = Arc::clone(&self.stream); + let (tx, rx) = mpsc::channel(opts.channel_buffer_size); + let max_chunk_size = if expect_empty { 1 } else { opts.max_chunk_size }; + let read_next_line_duration = opts.read_next_line_duration; + let should_continue_on_timeout = opts.should_continue_on_timeout; tokio::spawn(async move { + let mut guard = stream.lock().await; + let limited_reader = ChunkLimiter::new(&mut *guard, max_chunk_size); + let buf_reader = BufReader::new(limited_reader); + let mut framed = FramedRead::new(buf_reader, LinesCodec::new()); + loop { - let mut lock = stream_clone.lock().await; - if let Some(stream) = lock.as_mut() { - let mut response = String::new(); - loop { - let mut buf = [0; 1024]; - match stream.read(&mut buf).await { - Ok(0) => { - let _ = tx.send(Ok(String::new())).await; - return; - } - Ok(n) => { - response.push_str(&String::from_utf8_lossy(&buf[..n])); - if response.ends_with('\n') { + let result = timeout(read_next_line_duration, framed.next()).await; + match result { + Ok(Some(line_res)) => { + let line_res = line_res.map_err(|e| Box::new(e) as Box); + + match line_res { + Ok(line) => { + if expect_empty && !line.is_empty() { + let _ = tx + .send(Err(Box::new(CustomError(format!( + "Expected empty line, got: {:?}", + line + ))) + as Box)) + .await; + break; + } + + let _ = tx.send(Ok(line.clone())).await; + + if !should_continue(&line) { break; } } - Err(e) => { - let _ = tx.send(Err(Box::new(e) as Box)).await; - return; + Err(err) => { + let _ = tx.send(Err(err)).await; + break; } } } - let response = response.trim().to_string(); - - if expect_empty && !response.is_empty() { - let _ = tx.send(Err(Box::new(CustomError(format!("Expected empty string, got: {:?}", response))) as Box)).await; - return; - } - - let _ = tx.send(Ok(response.clone())).await; - - if !should_continue(&response) { + Ok(None) => { break; } - } else { - let _ = tx.send(Err(Box::new(CustomError("Not connected".to_string())) as Box)).await; - return; + Err(_) => { + if !should_continue_on_timeout { + break; + } + } } } }); @@ -197,26 +457,49 @@ impl Client { self.send_message(&json_string).await } - pub async fn receive_json(&self, should_continue: F) -> Result>>, Box> + pub async fn send_json_and_check(&self, payload: &Value) -> Result<(), Box> { + self.send_json(payload) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let mut rx = self + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + if rx + .recv() + .await + .ok_or_else(|| { + Box::new(CustomError("Expected response".to_string())) as Box + })?? + .is_empty() + { + } else { + return Err(Box::new(CustomError("Expected empty line".to_string()))); + } + + Ok(()) + } + + pub async fn receive_json( + &self, + should_continue: F, + opts: &ReceiveOptions, + ) -> Result>>, Box> where F: FnMut(&str) -> bool + Send + 'static, { - let mut rx = self.receive_message(false, should_continue).await?; - let (tx, new_rx) = mpsc::channel(100); + let mut rx = self.receive_message(false, should_continue, opts).await?; + let (tx, new_rx) = mpsc::channel(opts.channel_buffer_size); tokio::spawn(async move { while let Some(result) = rx.recv().await { - match result { - Ok(json_str) => { - match serde_json::from_str(&json_str) { - Ok(json_value) => { - if let Err(_) = tx.send(Ok(json_value)).await { - break; - } - } - Err(e) => { - let _ = tx.send(Err(Box::new(e) as Box)).await; - } + match result.and_then(|json_str| { + serde_json::from_str(&json_str) + .map_err(|e| Box::new(e) as Box) + }) { + Ok(json_value) => { + if let Err(_) = tx.send(Ok(json_value)).await { + break; } } Err(e) => { @@ -228,7 +511,6 @@ impl Client { Ok(new_rx) } - } #[cfg(feature = "python")] @@ -239,7 +521,21 @@ use pyo3::prelude::*; #[cfg(feature = "python")] #[pymodule] -fn srpc_client(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; +fn srpc_client(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + use tracing::level_filters::LevelFilter; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 9c867db2..cfffadd2 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,74 +1,372 @@ -use crate::Client; +use crate::{ClientConfig, Conn, ConnectedClient, ReceiveOptions, SimpleValue}; +use futures::{Stream, StreamExt}; +use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration}; use pyo3::prelude::*; -use pyo3::exceptions::PyRuntimeError; -use pyo3_asyncio; +use pyo3::types::PyFunction; use serde_json::Value; -use std::sync::Arc; -use tokio::sync::Mutex; +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::net::TcpStream; +use tokio::sync::{mpsc, Mutex}; +use tokio_openssl::SslStream; #[pyclass] -pub struct SrpcClient(Arc>); +pub struct SrpcClientConfig(ClientConfig); + +#[pyclass] +pub struct ConnectedSrpcClient(Arc>>>); + +#[pyclass] +pub struct SrpcMethodCallConn(Arc>>>); #[pymethods] -impl SrpcClient { +impl SrpcClientConfig { #[new] pub fn new(host: &str, port: u16, path: &str, cert: &str, key: &str) -> Self { - SrpcClient(Arc::new(Mutex::new(Client::new(host, port, path, cert, key)))) + SrpcClientConfig(ClientConfig::new(host, port, path, cert, key)) + } + + pub fn connect<'p>(&self, py: Python<'p>) -> PyResult> { + let client = self.0.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + client + .connect() + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + .map(|c| { + ConnectedSrpcClient(Arc::new(Mutex::new(c))) + }) + }) + } +} + +struct Streamer { + rx: mpsc::Receiver>>, +} + +impl Streamer { + fn new(rx: mpsc::Receiver>>) -> Self { + Streamer { rx } + } +} + +impl Stream for Streamer { + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.get_mut().rx.poll_recv(cx) + } +} + +#[pyo3::pyclass] +struct PyStream { + pub streamer: Arc>, +} + +impl PyStream { + fn new(streamer: Streamer) -> Self { + PyStream { + streamer: Arc::new(Mutex::new(streamer)), + } + } +} + +#[pymethods] +impl PyStream { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __anext__(&self, py: Python) -> PyResult> { + let streamer = self.streamer.clone(); + let future = pyo3_async_runtimes::tokio::future_into_py(py, async move { + let val = streamer.lock().await.next().await; + match val { + Some(Ok(val)) => Ok(val), + Some(Err(val)) => Err(PyRuntimeError::new_err(val.to_string())), + None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")), + } + }); + Ok(Some(future?.into())) + } +} + +struct ValueStreamer { + rx: mpsc::Receiver>>, +} + +impl ValueStreamer { + fn new( + rx: mpsc::Receiver>>, + ) -> Self { + ValueStreamer { rx } + } +} + +impl Stream for ValueStreamer { + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.get_mut().rx.poll_recv(cx) + } +} + +#[pyo3::pyclass] +struct PyValueStream { + pub streamer: Arc>, +} + +impl PyValueStream { + fn new(streamer: ValueStreamer) -> Self { + PyValueStream { + streamer: Arc::new(Mutex::new(streamer)), + } + } +} + +#[pymethods] +impl PyValueStream { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __anext__(&self, py: Python) -> PyResult> { + let streamer = self.streamer.clone(); + let future = pyo3_async_runtimes::tokio::future_into_py(py, async move { + let val = streamer.lock().await.next().await; + match val { + Some(Ok(val)) => Ok(val.to_string()), + Some(Err(val)) => Err(PyRuntimeError::new_err(val.to_string())), + None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")), + } + }); + Ok(Some(future?.into())) + } +} + +#[pymethods] +impl ConnectedSrpcClient { + pub fn send_message<'p>( + &'p self, + py: Python<'p>, + message: String, + ) -> PyResult> { + let client = self.0.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; + client + .send_message(&message) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) } - pub fn connect<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn send_message_and_check<'p>( + &'p self, + py: Python<'p>, + message: String, + ) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { - client.lock().await.connect().await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; + client + .send_message_and_check(&message) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) }) } - pub fn send_message<'p>(&self, py: Python<'p>, message: String) -> PyResult<&'p PyAny> { + pub fn receive_message<'p>( + &self, + py: Python<'p>, + expect_empty: bool, + should_continue: bool, + ) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { - client.lock().await.send_message(&message).await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; + let rx = client + .receive_message( + expect_empty, + move |_| should_continue, + &ReceiveOptions::default(), + ) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(Python::with_gil(|_py| PyStream::new(Streamer::new(rx)))) }) } - pub fn receive_message<'p>(&self, py: Python<'p>, expect_empty: bool) -> PyResult<&'p PyAny> { + pub fn receive_message_cb<'p>( + &self, + py: Python<'p>, + expect_empty: bool, + should_continue: Py, + ) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { - let mut rx = client.lock().await.receive_message(expect_empty, |_| false).await + let should_continue = should_continue.clone_ref(py); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let should_continue = move |response: &str| -> bool { + Python::with_gil(|py| { + should_continue + .call1(py, (response,)) + .and_then(|v| v.extract::(py)) + .unwrap_or(false) + }) + }; + let client = client.lock().await; + let rx = client + .receive_message(expect_empty, should_continue, &ReceiveOptions::default()) + .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - - let mut results = Vec::new(); - while let Some(result) = rx.recv().await { - match result { - Ok(message) => results.push(message), - Err(e) => return Err(PyRuntimeError::new_err(e.to_string())), - } - } - Ok(Python::with_gil(|py| results.to_object(py))) + + Ok(Python::with_gil(|_py| PyStream::new(Streamer::new(rx)))) }) } - pub fn send_json<'p>(&self, py: Python<'p>, payload: String) -> PyResult<&'p PyAny> { + pub fn send_json<'p>(&self, py: Python<'p>, payload: String) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { - let value: Value = serde_json::from_str(&payload).map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - client.lock().await.send_json(&value).await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value: Value = serde_json::from_str(&payload) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let client = client.lock().await; + client + .send_json(&value) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) }) } - pub fn receive_json<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn send_json_and_check<'p>( + &self, + py: Python<'p>, + payload: String, + ) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { - let mut rx = client.lock().await.receive_json(|_| false).await + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value: Value = serde_json::from_str(&payload) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - - let mut results = Vec::new(); - while let Some(result) = rx.recv().await { - match result { - Ok(json_value) => results.push(json_value.to_string()), - Err(e) => return Err(PyRuntimeError::new_err(e.to_string())), - } - } - Ok(Python::with_gil(|py| results.to_object(py))) + let client = client.lock().await; + client + .send_json_and_check(&value) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) + } + + pub fn receive_json<'p>( + &self, + py: Python<'p>, + should_continue: bool, + ) -> PyResult> { + let client = self.0.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; + let rx = client + .receive_json(move |_| should_continue, &ReceiveOptions::default()) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(Python::with_gil(|_py| { + PyValueStream::new(ValueStreamer::new(rx)) + })) + }) + } + + #[pyo3(signature = (should_continue, opts=None))] + pub fn receive_json_cb<'p>( + &self, + py: Python<'p>, + should_continue: Py, + opts: Option, + ) -> PyResult> { + let client = self.0.clone(); + let should_continue = should_continue.clone_ref(py); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let should_continue = move |response: &str| -> bool { + Python::with_gil(|py| { + should_continue + .call1(py, (response,)) + .and_then(|v| v.extract::(py)) + .unwrap_or(false) + }) + }; + let client = client.lock().await; + let rx = client + .receive_json(should_continue, &opts.unwrap_or_default()) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(Python::with_gil(|_py| { + PyValueStream::new(ValueStreamer::new(rx)) + })) + }) + } + + pub fn request_reply<'p>( + &self, + py: Python<'p>, + method: String, + payload: String, + ) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value: Value = serde_json::from_str(&payload) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let client = client.lock().await; + let response = client + .request_reply::(&method, value) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + // TODO: Figure out how to marshall this as a python dict + Ok(response.to_string()) + }) + } + + pub fn call<'p>(&self, py: Python<'p>, method: String) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let guard = client.lock_owned().await; + let conn = crate::call(guard, &method) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(SrpcMethodCallConn(Arc::new( + Mutex::new(conn), + ))) + }) + } +} + +#[pymethods] +impl SrpcMethodCallConn { + pub fn decode<'p>(&self, py: Python<'p>) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = client.lock_owned().await; + let response = guard + .decode() + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(response.to_string()) + }) + } + + pub fn close<'p>(&mut self, py: Python<'p>) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = client.lock_owned().await; + guard.close(); + Ok(()) }) } } diff --git a/rust/lib/srpc/client/src/tests.rs b/rust/lib/srpc/client/src/tests.rs new file mode 100644 index 00000000..e629be4c --- /dev/null +++ b/rust/lib/srpc/client/src/tests.rs @@ -0,0 +1 @@ +mod lib; diff --git a/rust/lib/srpc/client/src/tests/lib.rs b/rust/lib/srpc/client/src/tests/lib.rs new file mode 100644 index 00000000..5be614ec --- /dev/null +++ b/rust/lib/srpc/client/src/tests/lib.rs @@ -0,0 +1,113 @@ +use std::{error::Error, num::NonZeroU8}; + +use crate::{ClientConfig, ConnectedClient, ReceiveOptions}; + +use rstest::rstest; +use tokio::{ + io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream}, + sync::mpsc, +}; + +fn setup_test_client() -> (ConnectedClient, DuplexStream) { + let (client_stream, server_stream) = duplex(1024); + + let config = ClientConfig::new("example.com", 443, "/", "", ""); + (ConnectedClient::new(config, client_stream), server_stream) +} + +fn n_message(num: NonZeroU8) -> impl FnMut(&str) -> bool { + let mut seen = 0; + move |_msg: &str| { + if seen + 1 == num.get() { + false + } else { + seen += 1; + true + } + } +} + +fn one_message() -> impl FnMut(&str) -> bool { + n_message(NonZeroU8::new(1).unwrap()) +} + +async fn check_message( + server_message: &str, + rx: &mut mpsc::Receiver>>, +) { + if let Some(Ok(received_msg)) = rx.recv().await { + assert_eq!(received_msg, server_message.trim()); + } else { + panic!("Did not receive expected message from server"); + } +} + +async fn check_server( + client_message: &str, + server_stream: &mut DuplexStream, +) -> Result<(), Box> { + let mut server_buf = vec![0u8; client_message.len()]; + server_stream.read_exact(&mut server_buf).await?; + assert_eq!(&server_buf, client_message.as_bytes()); + Ok(()) +} + +#[test_log::test(rstest)] +#[tokio::test(start_paused = true)] +async fn test_connected_client_send_and_receive() -> Result<(), Box> { + let (connected_client, mut server_stream) = setup_test_client(); + + let client_message = "Hello from client\n"; + connected_client.send_message(client_message).await?; + + check_server(client_message, &mut server_stream).await?; + + let server_message = "Hello from server\n"; + server_stream.write_all(server_message.as_bytes()).await?; + + let should_continue = one_message(); + + let opts = ReceiveOptions::default(); + let mut rx = connected_client + .receive_message(false, should_continue, &opts) + .await?; + + check_message(server_message, &mut rx).await; + + Ok(()) +} + +#[test_log::test(rstest)] +#[tokio::test(start_paused = true)] +async fn test_connected_client_send_and_receive_stream() -> Result<(), Box> { + let (connected_client, mut server_stream) = setup_test_client(); + + let client_message = "Hello from client\n"; + connected_client.send_message(client_message).await?; + + check_server(client_message, &mut server_stream).await?; + + server_stream.write_all("\n".as_bytes()).await?; + + let should_continue = one_message(); + + let opts = ReceiveOptions::default(); + let mut rx = connected_client + .receive_message(true, should_continue, &opts) + .await?; + + check_message("", &mut rx).await; + + server_stream.write_all("first\n".as_bytes()).await?; + + server_stream.write_all("second\n".as_bytes()).await?; + + let should_continue = n_message(NonZeroU8::new(2).unwrap()); + let mut rx = connected_client + .receive_message(false, should_continue, &opts) + .await?; + + check_message("first", &mut rx).await; + check_message("second", &mut rx).await; + Ok(()) +} diff --git a/rust/lib/srpc/client/srpc_client.pyi b/rust/lib/srpc/client/srpc_client.pyi new file mode 100644 index 00000000..abed9fe1 --- /dev/null +++ b/rust/lib/srpc/client/srpc_client.pyi @@ -0,0 +1,23 @@ +from typing import Callable, List + +type JsonStr = str + +class SrpcClientConfig: + def __init__(self, host: str, port: int, path: str, cert: str, key: str) -> None: ... + async def connect(self) -> "ConnectedSrpcClient": ... + +class ConnectedSrpcClient: + async def send_message(self, message: str) -> None: ... + async def send_message_and_check(self, message: str) -> None: ... + async def receive_message(self, expect_empty: bool, should_continue: bool) -> List[str]: ... + async def receive_message_cb(self, expect_empty: bool, should_continue: Callable[[str], bool]) -> List[str]: ... + async def send_json(self, payload: str) -> None: ... + async def send_json_and_check(self, payload: str) -> None: ... + async def receive_json(self, should_continue: bool) -> List[str]: ... + async def receive_json_cb(self, should_continue: Callable[[str], bool]) -> List[str]: ... + async def request_reply(self, message: str, payload: JsonStr) -> JsonStr: ... + async def call(self, message: str) -> "SrpcMethodCallConn": ... + +class SrpcMethodCallConn: + async def decode(self) -> JsonStr: ... + async def close(self) -> None: ...