@@ -68,7 +68,7 @@ use std::{
rc::Rc,
sync::{
Arc, OnceLock,
- atomic::{AtomicBool, Ordering::SeqCst},
+ atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
},
time::{Duration, Instant},
};
@@ -89,10 +89,36 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
+const MAX_CONCURRENT_CONNECTIONS: usize = 512;
+
+static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
type MessageHandler =
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
+pub struct ConnectionGuard;
+
+impl ConnectionGuard {
+ pub fn try_acquire() -> Result<Self, ()> {
+ let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
+ if current_connections >= MAX_CONCURRENT_CONNECTIONS {
+ CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
+ tracing::error!(
+ "too many concurrent connections: {}",
+ current_connections + 1
+ );
+ return Err(());
+ }
+ Ok(ConnectionGuard)
+ }
+}
+
+impl Drop for ConnectionGuard {
+ fn drop(&mut self) {
+ CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
+ }
+}
+
struct Response<R> {
peer: Arc<Peer>,
receipt: Receipt<R>,
@@ -725,6 +751,7 @@ impl Server {
system_id: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
+ connection_guard: Option<ConnectionGuard>,
) -> impl Future<Output = ()> + use<> {
let this = self.clone();
let span = info_span!("handle connection", %address,
@@ -745,6 +772,7 @@ impl Server {
tracing::error!("server is tearing down");
return
}
+
let (connection_id, handle_io, mut incoming_rx) = this
.peer
.add_connection(connection, {
@@ -786,6 +814,7 @@ impl Server {
tracing::error!(?error, "failed to send initial client update");
return;
}
+ drop(connection_guard);
let handle_io = handle_io.fuse();
futures::pin_mut!(handle_io);
@@ -1157,6 +1186,19 @@ pub async fn handle_websocket_request(
}
let socket_address = socket_address.to_string();
+
+ // Acquire connection guard before WebSocket upgrade
+ let connection_guard = match ConnectionGuard::try_acquire() {
+ Ok(guard) => guard,
+ Err(()) => {
+ return (
+ StatusCode::SERVICE_UNAVAILABLE,
+ "Too many concurrent connections",
+ )
+ .into_response();
+ }
+ };
+
ws.on_upgrade(move |socket| {
let socket = socket
.map_ok(to_tungstenite_message)
@@ -1174,6 +1216,7 @@ pub async fn handle_websocket_request(
system_id_header.map(|header| header.to_string()),
None,
Executor::Production,
+ Some(connection_guard),
)
.await;
}
@@ -258,6 +258,7 @@ impl TestServer {
None,
Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()),
+ None,
))
.detach();
let connection_id = connection_id_rx.await.map_err(|e| {