@@ -22,7 +22,7 @@ use time::OffsetDateTime;
use tower::ServiceBuilder;
use tracing::instrument;
-pub fn routes(rpc_server: &Arc<rpc::Server>, state: Arc<AppState>) -> Router<Body> {
+pub fn routes(rpc_server: Arc<rpc::Server>, state: Arc<AppState>) -> Router<Body> {
Router::new()
.route("/user", get(get_authenticated_user))
.route("/users", get(get_users).post(create_user))
@@ -50,7 +50,7 @@ pub fn routes(rpc_server: &Arc<rpc::Server>, state: Arc<AppState>) -> Router<Bod
.layer(
ServiceBuilder::new()
.layer(Extension(state))
- .layer(Extension(rpc_server.clone()))
+ .layer(Extension(rpc_server))
.layer(middleware::from_fn(validate_api_token)),
)
}
@@ -9,6 +9,7 @@ mod db_tests;
#[cfg(test)]
mod integration_tests;
+use crate::rpc::ResultExt as _;
use axum::{body::Body, Router};
use collab::{Error, Result};
use db::{Db, PostgresDb};
@@ -18,6 +19,7 @@ use std::{
sync::Arc,
time::Duration,
};
+use tokio::signal;
use tracing_log::LogTracer;
use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
use util::ResultExt;
@@ -92,11 +94,12 @@ async fn main() -> Result<()> {
rpc_server.start_recording_project_activity(Duration::from_secs(5 * 60), rpc::RealExecutor);
let app = Router::<Body>::new()
- .merge(api::routes(&rpc_server, state.clone()))
- .merge(rpc::routes(rpc_server));
+ .merge(api::routes(rpc_server.clone(), state.clone()))
+ .merge(rpc::routes(rpc_server.clone()));
axum::Server::from_tcp(listener)?
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
+ .with_graceful_shutdown(graceful_shutdown(rpc_server, state))
.await?;
Ok(())
@@ -133,3 +136,52 @@ pub fn init_tracing(config: &Config) -> Option<()> {
None
}
+
+async fn graceful_shutdown(rpc_server: Arc<rpc::Server>, state: Arc<AppState>) {
+ let ctrl_c = async {
+ signal::ctrl_c()
+ .await
+ .expect("failed to install Ctrl+C handler");
+ };
+
+ #[cfg(unix)]
+ let terminate = async {
+ signal::unix::signal(signal::unix::SignalKind::terminate())
+ .expect("failed to install signal handler")
+ .recv()
+ .await;
+ };
+
+ #[cfg(not(unix))]
+ let terminate = std::future::pending::<()>();
+
+ tokio::select! {
+ _ = ctrl_c => {},
+ _ = terminate => {},
+ }
+
+ if let Some(live_kit) = state.live_kit_client.as_ref() {
+ let deletions = rpc_server
+ .store()
+ .await
+ .rooms()
+ .values()
+ .map(|room| {
+ let name = room.live_kit_room.clone();
+ async {
+ live_kit.delete_room(name).await.trace_err();
+ }
+ })
+ .collect::<Vec<_>>();
+
+ tracing::info!("deleting all live-kit rooms");
+ if let Err(_) = tokio::time::timeout(
+ Duration::from_secs(10),
+ futures::future::join_all(deletions),
+ )
+ .await
+ {
+ tracing::error!("timed out waiting for live-kit room deletion");
+ }
+ }
+}
@@ -519,6 +519,10 @@ impl Store {
self.rooms.get(&room_id)
}
+ pub fn rooms(&self) -> &BTreeMap<RoomId, proto::Room> {
+ &self.rooms
+ }
+
pub fn call(
&mut self,
room_id: RoomId,