Merge pull request #2256 from zed-industries/liveness-probe

Antonio Scandurra created

Introduce Kubernetes liveness probe to ensure database works

Change summary

crates/collab/k8s/manifest.template.yml |  7 +++++++
crates/collab/src/main.rs               | 15 +++++++++++++--
2 files changed, 20 insertions(+), 2 deletions(-)

Detailed changes

crates/collab/k8s/manifest.template.yml 🔗

@@ -59,6 +59,13 @@ spec:
           ports:
             - containerPort: 8080
               protocol: TCP
+          livenessProbe:
+            httpGet:
+              path: /healthz
+              port: 8080
+            initialDelaySeconds: 5
+            periodSeconds: 5
+            timeoutSeconds: 5
           readinessProbe:
             httpGet:
               path: /

crates/collab/src/main.rs 🔗

@@ -1,11 +1,12 @@
 use anyhow::anyhow;
-use axum::{routing::get, Router};
+use axum::{routing::get, Extension, Router};
 use collab::{db, env, executor::Executor, AppState, Config, MigrateConfig, Result};
 use db::Database;
 use std::{
     env::args,
     net::{SocketAddr, TcpListener},
     path::Path,
+    sync::Arc,
 };
 use tokio::signal::unix::SignalKind;
 use tracing_log::LogTracer;
@@ -66,7 +67,12 @@ async fn main() -> Result<()> {
 
             let app = collab::api::routes(rpc_server.clone(), state.clone())
                 .merge(collab::rpc::routes(rpc_server.clone()))
-                .merge(Router::new().route("/", get(handle_root)));
+                .merge(
+                    Router::new()
+                        .route("/", get(handle_root))
+                        .route("/healthz", get(handle_liveness_probe))
+                        .layer(Extension(state.clone())),
+                );
 
             axum::Server::from_tcp(listener)?
                 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
@@ -95,6 +101,11 @@ async fn handle_root() -> String {
     format!("collab v{VERSION}")
 }
 
+async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
+    state.db.get_all_users(0, 1).await?;
+    Ok("ok".to_string())
+}
+
 pub fn init_tracing(config: &Config) -> Option<()> {
     use std::str::FromStr;
     use tracing_subscriber::layer::SubscriberExt;