Skip to content

Pool Postgres connections #3043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ edition = { workspace = true }
[dependencies]
anyhow = { workspace = true }
chrono = "0.4"
deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] }
native-tls = "0.2"
postgres-native-tls = "0.5"
spin-core = { path = "../core" }
115 changes: 74 additions & 41 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,83 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use spin_world::spin::postgres::postgres::{
self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet,
};
use tokio_postgres::types::Type;
use tokio_postgres::{config::SslMode, types::ToSql, Row};
use tokio_postgres::{Client as TokioClient, NoTls, Socket};
use tokio_postgres::{config::SslMode, types::ToSql, NoTls, Row};

const CONNECTION_POOL_SIZE: usize = 64;

#[async_trait]
pub trait Client {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized;
pub trait ClientFactory: Send + Sync {
type Client: Client + Send + Sync + 'static;
fn new() -> Self;
async fn build_client(&mut self, address: &str) -> Result<Self::Client>;
}

pub struct PooledTokioClientFactory {
pools: std::collections::HashMap<String, deadpool_postgres::Pool>,
}

#[async_trait]
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;
fn new() -> Self {
Self {
pools: Default::default(),
}
}
async fn build_client(&mut self, address: &str) -> Result<Self::Client> {
let pool_entry = self.pools.entry(address.to_owned());
let pool = match pool_entry {
std::collections::hash_map::Entry::Occupied(entry) => entry.into_mut(),
std::collections::hash_map::Entry::Vacant(entry) => {
let pool = create_connection_pool(address)
.context("establishing PostgreSQL connection pool")?;
entry.insert(pool)
}
};

Ok(pool.get().await?)
}
}

fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
let config = address
.parse::<tokio_postgres::Config>()
.context("parsing Postgres connection string")?;

tracing::debug!("Build new connection: {}", address);

// TODO: This is slower but safer. Is it the right tradeoff?
// https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
let mgr_config = deadpool_postgres::ManagerConfig {
recycling_method: deadpool_postgres::RecyclingMethod::Clean,
};

let mgr = if config.get_ssl_mode() == SslMode::Disable {
deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
} else {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
deadpool_postgres::Manager::from_config(config, connector, mgr_config)
};

// TODO: what is our max size heuristic? Should this be passed in soe that different
// hosts can manage it according to their needs? Will a plain number suffice for
// sophisticated hosts anyway?
let pool = deadpool_postgres::Pool::builder(mgr)
.max_size(CONNECTION_POOL_SIZE)
.build()
.context("building Postgres connection pool")?;

Ok(pool)
}

#[async_trait]
pub trait Client {
async fn execute(
&self,
statement: String,
@@ -29,28 +92,7 @@ pub trait Client {
}

#[async_trait]
impl Client for TokioClient {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized,
{
let config = address.parse::<tokio_postgres::Config>()?;

tracing::debug!("Build new connection: {}", address);

if config.get_ssl_mode() == SslMode::Disable {
let (client, connection) = config.connect(NoTls).await?;
spawn_connection(connection);
Ok(client)
} else {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
let (client, connection) = config.connect(connector).await?;
spawn_connection(connection);
Ok(client)
}
}

impl Client for deadpool_postgres::Object {
async fn execute(
&self,
statement: String,
@@ -67,7 +109,8 @@ impl Client for TokioClient {
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
.collect();

self.execute(&statement, params_refs.as_slice())
self.as_ref()
.execute(&statement, params_refs.as_slice())
.await
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))
}
@@ -89,6 +132,7 @@ impl Client for TokioClient {
.collect();

let results = self
.as_ref()
.query(&statement, params_refs.as_slice())
.await
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
@@ -111,17 +155,6 @@ impl Client for TokioClient {
}
}

fn spawn_connection<T>(connection: tokio_postgres::Connection<Socket, T>)
where
T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
{
tokio::spawn(async move {
if let Err(e) = connection.await {
tracing::error!("Postgres connection error: {}", e);
}
});
}

fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Sync>> {
match value {
ParameterValue::Boolean(v) => Ok(Box::new(*v)),
25 changes: 14 additions & 11 deletions crates/factor-outbound-pg/src/host.rs
Original file line number Diff line number Diff line change
@@ -9,17 +9,20 @@ use tracing::field::Empty;
use tracing::instrument;
use tracing::Level;

use crate::client::Client;
use crate::client::{Client, ClientFactory};
use crate::InstanceState;

impl<C: Client> InstanceState<C> {
impl<CF: ClientFactory> InstanceState<CF> {
async fn open_connection<Conn: 'static>(
&mut self,
address: &str,
) -> Result<Resource<Conn>, v3::Error> {
self.connections
.push(
C::build_client(address)
self.client_factory
.write()
.await
.build_client(address)
.await
.map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?,
)
@@ -30,7 +33,7 @@ impl<C: Client> InstanceState<C> {
async fn get_client<Conn: 'static>(
&mut self,
connection: Resource<Conn>,
) -> Result<&C, v3::Error> {
) -> Result<&CF::Client, v3::Error> {
self.connections
.get(connection.rep())
.ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into()))
@@ -71,8 +74,8 @@ fn v2_params_to_v3(
params.into_iter().map(|p| p.try_into()).collect()
}

impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnection
for InstanceState<C>
impl<CF: ClientFactory + Send + Sync> spin_world::spin::postgres::postgres::HostConnection
for InstanceState<CF>
{
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
@@ -122,13 +125,13 @@ impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnecti
}
}

impl<C: Send> v2_types::Host for InstanceState<C> {
impl<CF: ClientFactory + Send> v2_types::Host for InstanceState<CF> {
fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
Ok(error)
}
}

impl<C: Send + Sync + Client> v3::Host for InstanceState<C> {
impl<CF: Send + Sync + ClientFactory> v3::Host for InstanceState<CF> {
fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
Ok(error)
}
@@ -152,9 +155,9 @@ macro_rules! delegate {
}};
}

impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}
impl<CF: Send + Sync + ClientFactory> v2::Host for InstanceState<CF> {}

impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
impl<CF: Send + Sync + ClientFactory> v2::HostConnection for InstanceState<CF> {
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
spin_factor_outbound_networking::record_address_fields(&address);
@@ -206,7 +209,7 @@ impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
}
}

impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
impl<CF: Send + Sync + ClientFactory> v1::Host for InstanceState<CF> {
async fn execute(
&mut self,
address: String,
26 changes: 15 additions & 11 deletions crates/factor-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
pub mod client;
mod host;

use client::Client;
use std::sync::Arc;

use client::ClientFactory;
use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor};
use spin_factors::{
anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
};
use tokio_postgres::Client as PgClient;
use tokio::sync::RwLock;

pub struct OutboundPgFactor<C = PgClient> {
_phantom: std::marker::PhantomData<C>,
pub struct OutboundPgFactor<CF = crate::client::PooledTokioClientFactory> {
_phantom: std::marker::PhantomData<CF>,
}

impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
impl<CF: ClientFactory + Send + Sync + 'static> Factor for OutboundPgFactor<CF> {
type RuntimeConfig = ();
type AppState = ();
type InstanceBuilder = InstanceState<C>;
type AppState = Arc<RwLock<CF>>;
type InstanceBuilder = InstanceState<CF>;

fn init<T: Send + 'static>(
&mut self,
@@ -31,7 +33,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
&self,
_ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
Ok(())
Ok(Arc::new(RwLock::new(CF::new())))
}

fn prepare<T: RuntimeFactors>(
@@ -43,6 +45,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
.allowed_hosts();
Ok(InstanceState {
allowed_hosts,
client_factory: ctx.app_state().clone(),
connections: Default::default(),
})
}
@@ -62,9 +65,10 @@ impl<C> OutboundPgFactor<C> {
}
}

pub struct InstanceState<C> {
pub struct InstanceState<CF: ClientFactory> {
allowed_hosts: OutboundAllowedHosts,
connections: spin_resource_table::Table<C>,
client_factory: Arc<RwLock<CF>>,
connections: spin_resource_table::Table<CF::Client>,
}

impl<C: Send + 'static> SelfInstanceBuilder for InstanceState<C> {}
impl<CF: ClientFactory + Send + 'static> SelfInstanceBuilder for InstanceState<CF> {}
20 changes: 13 additions & 7 deletions crates/factor-outbound-pg/tests/factor_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::{bail, Result};
use spin_factor_outbound_networking::OutboundNetworkingFactor;
use spin_factor_outbound_pg::client::Client;
use spin_factor_outbound_pg::client::ClientFactory;
use spin_factor_outbound_pg::OutboundPgFactor;
use spin_factor_variables::VariablesFactor;
use spin_factors::{anyhow, RuntimeFactors};
@@ -15,14 +16,14 @@ use spin_world::spin::postgres::postgres::{ParameterValue, RowSet};
struct TestFactors {
variables: VariablesFactor,
networking: OutboundNetworkingFactor,
pg: OutboundPgFactor<MockClient>,
pg: OutboundPgFactor<MockClientFactory>,
}

fn factors() -> TestFactors {
TestFactors {
variables: VariablesFactor::default(),
networking: OutboundNetworkingFactor::new(),
pg: OutboundPgFactor::<MockClient>::new(),
pg: OutboundPgFactor::<MockClientFactory>::new(),
}
}

@@ -104,17 +105,22 @@ async fn exercise_query() -> anyhow::Result<()> {
}

// TODO: We can expand this mock to track calls and simulate return values
pub struct MockClientFactory {}
pub struct MockClient {}

#[async_trait]
impl Client for MockClient {
async fn build_client(_address: &str) -> anyhow::Result<Self>
where
Self: Sized,
{
impl ClientFactory for MockClientFactory {
type Client = MockClient;
fn new() -> Self {
Self {}
}
async fn build_client(&mut self, _address: &str) -> Result<Self::Client> {
Ok(MockClient {})
}
}

#[async_trait]
impl Client for MockClient {
async fn execute(
&self,
_statement: String,