Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 22 additions & 28 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use crate::proxy::Proxy;
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Response;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::sync::RwLock;
use tracing::debug;
use tracing::error;
use tracing::info;
Expand Down Expand Up @@ -66,11 +67,7 @@ impl Proxy for HttpProxy {
}

/// Handles the proxying of HTTP connections to configured targets.
async fn proxy(
&self,
mut connection: Connection,
routing_idx: Arc<RwLock<usize>>,
) -> Result<()> {
async fn proxy(&self, mut connection: Connection, routing_idx: Arc<AtomicUsize>) -> Result<()> {
if let Some(backends) = connection.targets.get(&connection.target_name) {
let backend_count = backends.len();
if backend_count == 0 {
Expand Down Expand Up @@ -103,31 +100,28 @@ impl Proxy for HttpProxy {
let method = http_info[0].clone();
let request_path = http_info[1].clone();

// Limit the scope of the index write lock.
let http_backend: String;
{
let mut idx = routing_idx.write().await;

debug!(
"[{}] {backend_count} backends configured for {}, current index {idx}",
self.protocol_type(),
&connection.target_name
);
// Reset index when out of bounds to route back to the first server.
if routing_idx.load(Ordering::Acquire) >= backend_count {
debug!("Routing index reset");
routing_idx.store(0, Ordering::Relaxed);
}

// Reset index when out of bounds to route back to the first server.
if *idx >= backend_count {
*idx = 0;
}
let backend_idx = routing_idx.load(Ordering::Relaxed);
let http_backend = format!(
"http://{}:{}{}",
backends[backend_idx].host, backends[backend_idx].port, request_path
);

http_backend = format!(
"http://{}:{}{}",
backends[*idx].host, backends[*idx].port, request_path
);
debug!(
"[{}] {backend_count} backends configured for {}, current index {}",
self.protocol_type(),
&connection.target_name,
routing_idx.load(Ordering::Relaxed),
);

// Increment a shared index after we've constructed our current connection
// address.
*idx += 1;
}
// Increment a shared index after we've constructed our current connection
// address.
routing_idx.fetch_add(1, Ordering::Relaxed);

info!(
"[{}] Attempting to connect to {}",
Expand Down
57 changes: 27 additions & 30 deletions src/https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use crate::proxy::Proxy;
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Response;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::sync::RwLock;
use tracing::debug;
use tracing::error;
use tracing::info;
Expand Down Expand Up @@ -62,7 +63,7 @@ impl Proxy for HttpsProxy {
}

/// Handles the proxying of HTTP connections to configured targets.
async fn proxy(&self, connection: Connection, routing_idx: Arc<RwLock<usize>>) -> Result<()> {
async fn proxy(&self, connection: Connection, routing_idx: Arc<AtomicUsize>) -> Result<()> {
if let Some(backends) = connection.targets.get(&connection.target_name) {
let backend_count = backends.len();
if backend_count == 0 {
Expand Down Expand Up @@ -102,36 +103,32 @@ impl Proxy for HttpsProxy {
let method = http_info[0].clone();
let request_path = http_info[1].clone();

// Limit the scope of the index write lock.
let http_backend: String;
{
let mut idx = routing_idx.write().await;

debug!(
"[{}] {backend_count} backends configured for {}, current index {idx}",
self.protocol_type(),
&connection.target_name
);
// Reset index when out of bounds to route back to the first server.
if routing_idx.load(Ordering::Acquire) >= backend_count {
routing_idx.store(0, Ordering::Relaxed);
}

// Reset index when out of bounds to route back to the first server.
if *idx >= backend_count {
*idx = 0;
}
let backend_idx = routing_idx.load(Ordering::Relaxed);
let https_backend = format!(
"https://{}:{}{}",
backends[backend_idx].host, backends[backend_idx].port, request_path
);

http_backend = format!(
"http://{}:{}{}",
backends[*idx].host, backends[*idx].port, request_path
);
debug!(
"[{}] {backend_count} backends configured for {}, current index {}",
self.protocol_type(),
&connection.target_name,
routing_idx.load(Ordering::Relaxed),
);

// Increment a shared index after we've constructed our current connection
// address.
*idx += 1;
}
// Increment a shared index after we've constructed our current connection
// address.
routing_idx.fetch_add(1, Ordering::Relaxed);

info!(
"[{}] Attempting to connect to {}",
self.protocol_type(),
&http_backend
&https_backend
);

match method.as_str() {
Expand All @@ -140,11 +137,11 @@ impl Proxy for HttpsProxy {
.client
.as_ref()
.unwrap()
.get(&http_backend)
.get(&https_backend)
.send()
.await
.with_context(|| {
format!("unable to send response to {http_backend}")
format!("unable to send response to {https_backend}")
})?;
let response = self.construct_response(backend_response).await?;

Expand All @@ -155,11 +152,11 @@ impl Proxy for HttpsProxy {
.client
.as_ref()
.unwrap()
.post(&http_backend)
.post(&https_backend)
.send()
.await
.with_context(|| {
format!("unable to send response to {http_backend}")
format!("unable to send response to {https_backend}")
})?;
let response = self.construct_response(backend_response).await?;

Expand All @@ -172,7 +169,7 @@ impl Proxy for HttpsProxy {
info!(
"[{}] response sent to {}",
self.protocol_type(),
&http_backend
&https_backend
);
}
Err(e) => error!("unable to accept TLS stream: {e}"),
Expand Down
10 changes: 3 additions & 7 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use anyhow::Result;
use async_trait::async_trait;
use dashmap::DashMap;
use std::iter::Iterator;
use std::sync::atomic::AtomicUsize;
use std::{sync::Arc, vec};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio_native_tls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tracing::{error, info};
Expand All @@ -20,11 +20,7 @@ use tracing::{error, info};
pub trait Proxy: Send + Sync + Copy + 'static {
/// Proxy a `TcpStream` from an incoming connection to configured targets, with accompanying
/// `Connection` related data.
async fn proxy(
&self,
mut connection: Connection,
routing_idx: Arc<RwLock<usize>>,
) -> Result<()>;
async fn proxy(&self, mut connection: Connection, routing_idx: Arc<AtomicUsize>) -> Result<()>;

/// Retrieve the type of protocol in use by the current proxy.
fn protocol_type(&self) -> Protocol;
Expand All @@ -42,7 +38,7 @@ pub async fn accept<P>(
where
P: Proxy,
{
let idx: Arc<RwLock<usize>> = Arc::new(RwLock::new(0));
let idx: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
for conf in listeners {
let idx = idx.clone();

Expand Down
44 changes: 20 additions & 24 deletions src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use crate::proxy::Connection;
use crate::proxy::Proxy;
use anyhow::Result;
use async_trait::async_trait;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tracing::debug;
use tracing::info;

Expand Down Expand Up @@ -39,11 +40,7 @@ impl Proxy for TcpProxy {
}

/// Handles the proxying of TCP connections to configured targets.
async fn proxy(
&self,
mut connection: Connection,
routing_idx: Arc<RwLock<usize>>,
) -> Result<()> {
async fn proxy(&self, mut connection: Connection, routing_idx: Arc<AtomicUsize>) -> Result<()> {
if let Some(backends) = connection.targets.get(&connection.target_name) {
let backend_count = backends.len();
if backend_count == 0 {
Expand All @@ -55,26 +52,25 @@ impl Proxy for TcpProxy {
}
debug!("Backends configured {:?}", &backends);

// Limit the scope of the index write lock.
let backend_addr: String;
{
let mut idx = routing_idx.write().await;
debug!(
"[TCP] {backend_count} backends configured for {}, current index {idx}",
&connection.target_name
);

// Reset index when out of bounds to route back to the first server.
if *idx >= backend_count {
*idx = 0;
}
// Reset index when out of bounds to route back to the first server.
if routing_idx.load(Ordering::Acquire) >= backend_count {
routing_idx.store(0, Ordering::Relaxed);
}

backend_addr = format!("{}:{}", backends[*idx].host, backends[*idx].port);
let backend_idx = routing_idx.load(Ordering::Relaxed);
let backend_addr = format!(
"{}:{}",
backends[backend_idx].host, backends[backend_idx].port
);
debug!(
"[TCP] {backend_count} backends configured for {}, current index {}",
&connection.target_name,
routing_idx.load(Ordering::Relaxed),
);

// Increment a shared index after we've constructed our current connection
// address.
*idx += 1;
}
// Increment a shared index after we've constructed our current connection
// address.
routing_idx.fetch_add(1, Ordering::Relaxed);

info!("[TCP] Attempting to connect to {}", &backend_addr);

Expand Down
2 changes: 1 addition & 1 deletion tests/routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async fn route_to_healthy_targets() {

for _ in 0..=4 {
let response = http_client
.get("http://localhost:8080")
.get("http://127.0.0.1:8080")
.send()
.await
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion tests/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async fn route_tls_target() {

for _ in 0..=4 {
let response = https_client
.get("https://localhost:8443")
.get("https://127.0.0.1:8443")
.send()
.await
.unwrap();
Expand Down