Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@ native-tls = "0.2"
postgres-native-tls = "0.5"
postgres_range = "0.11"
rust_decimal = { version = "1.37", features = ["db-tokio-postgres"] }
tokio-rustls = { workspace = true }
serde_json = { workspace = true }
spin-common = { path = "../common" }
spin-core = { path = "../core" }
spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-locked-app = { path = "../locked-app" }
spin-resource-table = { path = "../table" }
spin-world = { path = "../world" }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] }
tracing = { workspace = true }
url = { workspace = true }
uuid = "1"

[dev-dependencies]
Expand Down
114 changes: 103 additions & 11 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::{Context, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_locked_app::locked::ContentPath;
use spin_world::async_trait;
use spin_world::spin::postgres4_0_0::postgres::{
self as v4, Column, DbValue, ParameterValue, RowSet,
Expand All @@ -21,19 +22,22 @@ const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16;
pub trait ClientFactory: Default + Send + Sync + 'static {
/// The type of client produced by `get_client`.
type Client: Client;
// fn new(component_tls_configs: ComponentTlsClientConfigs) -> Self;
/// Gets a client from the factory.
async fn get_client(&self, address: &str) -> Result<Self::Client>;
async fn get_client(&self, address: &str, assets: &[ContentPath]) -> Result<Self::Client>;
}

/// A `ClientFactory` that uses a connection pool per address.
pub struct PooledTokioClientFactory {
pools: moka::sync::Cache<String, deadpool_postgres::Pool>,
// component_tls_configs: Option<ComponentTlsClientConfigs>,
}

impl Default for PooledTokioClientFactory {
fn default() -> Self {
Self {
pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY),
// component_tls_configs: None,
}
}
}
Expand All @@ -42,21 +46,80 @@ impl Default for PooledTokioClientFactory {
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;

async fn get_client(&self, address: &str) -> Result<Self::Client> {
async fn get_client(&self, address: &str, assets: &[ContentPath]) -> Result<Self::Client> {
let pool = self
.pools
.try_get_with_by_ref(address, || create_connection_pool(address))
.try_get_with_by_ref(address, || create_connection_pool(address, assets))
.map_err(ArcError)
.context("establishing PostgreSQL connection pool")?;

Ok(pool.get().await?)
}
}

pub(crate) struct SuperConfig {
pub config: tokio_postgres::Config,
ssl_root_cert: Option<String>,
}

impl std::str::FromStr for SuperConfig {
type Err = <tokio_postgres::Config as std::str::FromStr>::Err;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let (ssl_root_cert, conn_str) =
if s.starts_with("postgresql://") || s.starts_with("postgres://") {
let url = url::Url::parse(s).unwrap();
let (sr, rest): (Vec<_>, _) = url.query_pairs().partition(|t| t.0 == "sslrootcert");
let mut url2 = url::Url::parse(s).unwrap();
let url2 = url2
.query_pairs_mut()
.clear()
.extend_pairs(rest)
.finish()
.to_string();
// let sr = url.query_pairs().find(|(k,v)| k == "sslrootcert").map(|t| t.1.to_string());
let sr = match sr.len() {
0 => None,
1 => Some(sr[0].1.to_string()),
_ => panic!("oh no"),
};
(sr, url2)
} else {
let bits = s.split(' ');
let (ssl_root_certs, rest): (Vec<_>, _) =
bits.partition(|bit| bit.strip_prefix("sslrootcert=").is_some());
let ssl_root_certs = ssl_root_certs
.into_iter()
.filter_map(|e| e.strip_prefix("sslrootcert="))
.collect::<Vec<_>>();
let sr = match ssl_root_certs.len() {
0 => None,
1 => Some(ssl_root_certs[0].to_owned()),
_ => panic!("oh no"),
};
let rest = rest.join(" ");
(sr, rest)
};

eprintln!("SSL ROOT: {ssl_root_cert:?}");
eprintln!("CONFIG: {conn_str}");

let config = conn_str.parse::<tokio_postgres::Config>()?;

Ok(Self {
ssl_root_cert,
config,
})
}
}

/// Creates a Postgres connection pool for the given address.
fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
let config = address
.parse::<tokio_postgres::Config>()
fn create_connection_pool(
address: &str,
assets: &[ContentPath],
) -> Result<deadpool_postgres::Pool> {
let super_config = address
.parse::<SuperConfig>()
.context("parsing Postgres connection string")?;

tracing::debug!("Build new connection: {}", address);
Expand All @@ -65,12 +128,41 @@ fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
recycling_method: deadpool_postgres::RecyclingMethod::Clean,
};

let mgr = if config.get_ssl_mode() == SslMode::Disable {
deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
let mgr = if super_config.config.get_ssl_mode() == SslMode::Disable {
deadpool_postgres::Manager::from_config(super_config.config, NoTls, mgr_config)
} else {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
deadpool_postgres::Manager::from_config(config, connector, mgr_config)
match super_config.ssl_root_cert.as_ref() {
None => deadpool_postgres::Manager::from_config(super_config.config, NoTls, mgr_config),
Some(ca_file_path) => {
if assets.len() == 1 && assets[0].path.display().to_string() == "/" {
// we are in a copy-mount scenario and can party on
} else {
anyhow::bail!("PostgreSQL sslrootcert is not yet supported with direct mounts")
};

let asset_root_url = assets[0]
.content
.source
.as_ref()
.context("LockedComponentSource missing source field")?;
let asset_root_dir = spin_common::url::parse_file_url(asset_root_url)?;

let ca_file_rel_path = ca_file_path.trim_start_matches('/');
let ca_file_abs_path = asset_root_dir.join(ca_file_rel_path);
eprintln!("CA FILE: {}", ca_file_abs_path.display());

if !ca_file_abs_path.is_file() {
anyhow::bail!("file {ca_file_path} does not exist in this component");
}

let cert_bytes = std::fs::read(&ca_file_abs_path)?;

let mut builder = TlsConnector::builder();
builder.add_root_certificate(native_tls::Certificate::from_pem(&cert_bytes)?);
let connector = MakeTlsConnector::new(builder.build()?);
deadpool_postgres::Manager::from_config(super_config.config, connector, mgr_config)
}
}
};

// TODO: what is our max size heuristic? Should this be passed in so that different
Expand Down
7 changes: 4 additions & 3 deletions crates/factor-outbound-pg/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl<CF: ClientFactory> InstanceState<CF> {
self.connections
.push(
self.client_factory
.get_client(address)
.get_client(address, &self.assets)
.await
.map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?,
)
Expand All @@ -48,8 +48,9 @@ impl<CF: ClientFactory> InstanceState<CF> {
}

let config = address
.parse::<tokio_postgres::Config>()
.map_err(|e| conn_failed(e.to_string()))?;
.parse::<crate::client::SuperConfig>()
.map_err(|e| conn_failed(e.to_string()))?
.config;

for (i, host) in config.get_hosts().iter().enumerate() {
match host {
Expand Down
24 changes: 16 additions & 8 deletions crates/factor-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod client;
mod host;
mod types;

use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use client::ClientFactory;
use spin_factor_outbound_networking::{
Expand All @@ -19,7 +19,7 @@ pub struct OutboundPgFactor<CF = crate::client::PooledTokioClientFactory> {

impl<CF: ClientFactory> Factor for OutboundPgFactor<CF> {
type RuntimeConfig = ();
type AppState = Arc<CF>;
type AppState = HashMap<String, Arc<CF>>;
type InstanceBuilder = InstanceState<CF>;

fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
Expand All @@ -36,22 +36,29 @@ impl<CF: ClientFactory> Factor for OutboundPgFactor<CF> {

fn configure_app<T: RuntimeFactors>(
&self,
_ctx: ConfigureAppContext<T, Self>,
ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
Ok(Arc::new(CF::default()))
let mut client_factories = HashMap::new();
for comp in ctx.app().components() {
client_factories.insert(comp.id().to_string(), Arc::new(CF::default()));
}
Ok(client_factories)
}

fn prepare<T: RuntimeFactors>(
&self,
mut ctx: PrepareContext<T, Self>,
) -> anyhow::Result<Self::InstanceBuilder> {
let allowed_hosts = ctx
.instance_builder::<OutboundNetworkingFactor>()?
.allowed_hosts();
let outbound_networking = ctx.instance_builder::<OutboundNetworkingFactor>()?;
let allowed_hosts = outbound_networking.allowed_hosts();
let cf = ctx.app_state().get(ctx.app_component().id()).unwrap();
let assets = ctx.app_component().files().cloned().collect();

Ok(InstanceState {
allowed_hosts,
client_factory: ctx.app_state().clone(),
client_factory: cf.clone(),
connections: Default::default(),
assets,
})
}
}
Expand All @@ -74,6 +81,7 @@ pub struct InstanceState<CF: ClientFactory> {
allowed_hosts: OutboundAllowedHosts,
client_factory: Arc<CF>,
connections: spin_resource_table::Table<CF::Client>,
assets: Vec<spin_locked_app::locked::ContentPath>,
}

impl<CF: ClientFactory> SelfInstanceBuilder for InstanceState<CF> {}
6 changes: 5 additions & 1 deletion crates/factor-outbound-pg/tests/factor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ pub struct MockClient {}
#[async_trait]
impl ClientFactory for MockClientFactory {
type Client = MockClient;
async fn get_client(&self, _address: &str) -> Result<Self::Client> {
async fn get_client(
&self,
_address: &str,
_assets: &[spin_locked_app::locked::ContentPath],
) -> Result<Self::Client> {
Ok(MockClient {})
}
}
Expand Down
4 changes: 4 additions & 0 deletions tests/manual/pg-ssl-root-certs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
target/
.spin/
pg
postgres-ssl
Loading