Refactor: Restructure collab main function to prepare for new subcommand: `serve llm` (#15824)

Max Brunsfeld, Marshall, and Jason created

This is just a refactor that we're landing ahead of any functional
changes to make sure we haven't broken anything.

Release Notes:

- N/A

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Jason <jason@zed.dev>

Change summary

Procfile                  |   2 
crates/collab/src/api.rs  |   9 -
crates/collab/src/lib.rs  |  26 +++++
crates/collab/src/llm.rs  |  16 +++
crates/collab/src/main.rs | 178 +++++++++++++++++++++++-----------------
5 files changed, 144 insertions(+), 87 deletions(-)

Detailed changes

Procfile 🔗

@@ -1,3 +1,3 @@
-collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve
+collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
 livekit: livekit-server --dev
 blob_store: ./script/run-local-minio

crates/collab/src/api.rs 🔗

@@ -61,7 +61,7 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
     }
 }
 
-pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
+pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
     Router::new()
         .route("/user", get(get_authenticated_user))
         .route("/users/:id/access_tokens", post(create_access_token))
@@ -70,7 +70,6 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
         .merge(contributors::router())
         .layer(
             ServiceBuilder::new()
-                .layer(Extension(state))
                 .layer(Extension(rpc_server))
                 .layer(middleware::from_fn(validate_api_token)),
         )
@@ -152,12 +151,8 @@ struct CreateUserParams {
 }
 
 async fn get_rpc_server_snapshot(
-    Extension(rpc_server): Extension<Option<Arc<rpc::Server>>>,
+    Extension(rpc_server): Extension<Arc<rpc::Server>>,
 ) -> Result<ErasedJson> {
-    let Some(rpc_server) = rpc_server else {
-        return Err(Error::Internal(anyhow!("rpc server is not available")));
-    };
-
     Ok(ErasedJson::pretty(rpc_server.snapshot().await))
 }
 

crates/collab/src/lib.rs 🔗

@@ -3,6 +3,7 @@ pub mod auth;
 pub mod db;
 pub mod env;
 pub mod executor;
+pub mod llm;
 mod rate_limiter;
 pub mod rpc;
 pub mod seed;
@@ -124,7 +125,7 @@ impl std::fmt::Display for Error {
 
 impl std::error::Error for Error {}
 
-#[derive(Deserialize)]
+#[derive(Clone, Deserialize)]
 pub struct Config {
     pub http_port: u16,
     pub database_url: String,
@@ -176,6 +177,29 @@ impl Config {
     }
 }
 
+/// The service mode that collab should run in.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum ServiceMode {
+    Api,
+    Collab,
+    Llm,
+    All,
+}
+
+impl ServiceMode {
+    pub fn is_collab(&self) -> bool {
+        matches!(self, Self::Collab | Self::All)
+    }
+
+    pub fn is_api(&self) -> bool {
+        matches!(self, Self::Api | Self::All)
+    }
+
+    pub fn is_llm(&self) -> bool {
+        matches!(self, Self::Llm | Self::All)
+    }
+}
+
 pub struct AppState {
     pub db: Arc<Database>,
     pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,

crates/collab/src/llm.rs 🔗

@@ -0,0 +1,16 @@
+use std::sync::Arc;
+
+use crate::{executor::Executor, Config, Result};
+
+pub struct LlmState {
+    pub config: Config,
+    pub executor: Executor,
+}
+
+impl LlmState {
+    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
+        let this = Self { config, executor };
+
+        Ok(Arc::new(this))
+    }
+}

crates/collab/src/main.rs 🔗

@@ -5,7 +5,7 @@ use axum::{
     routing::get,
     Extension, Router,
 };
-use collab::api::billing::poll_stripe_events_periodically;
+use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
 use collab::{
     api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
     rpc::ResultExt, AppState, Config, RateLimiter, Result,
@@ -56,88 +56,99 @@ async fn main() -> Result<()> {
             collab::seed::seed(&config, &db, true).await?;
         }
         Some("serve") => {
-            let (is_api, is_collab) = if let Some(next) = args.next() {
-                (next == "api", next == "collab")
-            } else {
-                (true, true)
+            let mode = match args.next().as_deref() {
+                Some("collab") => ServiceMode::Collab,
+                Some("api") => ServiceMode::Api,
+                Some("llm") => ServiceMode::Llm,
+                Some("all") => ServiceMode::All,
+                _ => {
+                    return Err(anyhow!(
+                        "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
+                    ))?;
+                }
             };
-            if !is_api && !is_collab {
-                Err(anyhow!(
-                    "usage: collab <version | migrate | seed | serve [api|collab]>"
-                ))?;
-            }
 
             let config = envy::from_env::<Config>().expect("error loading config");
             init_tracing(&config);
+            let mut app = Router::new()
+                .route("/", get(handle_root))
+                .route("/healthz", get(handle_liveness_probe))
+                .layer(Extension(mode));
 
-            run_migrations(&config).await?;
-
-            let state = AppState::new(config, Executor::Production).await?;
-
-            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
+            let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
                 .expect("failed to bind TCP listener");
 
-            let rpc_server = if is_collab {
-                let epoch = state
-                    .db
-                    .create_server(&state.config.zed_environment)
-                    .await?;
-                let rpc_server = collab::rpc::Server::new(epoch, state.clone());
-                rpc_server.start().await?;
-
-                Some(rpc_server)
-            } else {
-                None
-            };
+            let mut on_shutdown = None;
 
-            if is_collab {
-                state.db.purge_old_embeddings().await.trace_err();
-                RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
-            }
+            if mode.is_llm() {
+                let state = LlmState::new(config.clone(), Executor::Production).await?;
 
-            if is_api {
-                poll_stripe_events_periodically(state.clone());
-                fetch_extensions_from_blob_store_periodically(state.clone());
+                app = app.layer(Extension(state.clone()));
             }
 
-            let mut app = collab::api::routes(rpc_server.clone(), state.clone());
-            if let Some(rpc_server) = rpc_server.clone() {
-                app = app.merge(collab::rpc::routes(rpc_server))
-            }
-            app = app
-                .merge(
-                    Router::new()
-                        .route("/", get(handle_root))
-                        .route("/healthz", get(handle_liveness_probe))
-                        .merge(collab::api::extensions::router())
+            if mode.is_collab() || mode.is_api() {
+                run_migrations(&config).await?;
+
+                let state = AppState::new(config, Executor::Production).await?;
+
+                if mode.is_collab() {
+                    state.db.purge_old_embeddings().await.trace_err();
+                    RateLimiter::save_periodically(
+                        state.rate_limiter.clone(),
+                        state.executor.clone(),
+                    );
+
+                    let epoch = state
+                        .db
+                        .create_server(&state.config.zed_environment)
+                        .await?;
+                    let rpc_server = collab::rpc::Server::new(epoch, state.clone());
+                    rpc_server.start().await?;
+
+                    app = app
+                        .merge(collab::api::routes(rpc_server.clone()))
+                        .merge(collab::rpc::routes(rpc_server.clone()));
+
+                    on_shutdown = Some(Box::new(move || rpc_server.teardown()));
+                }
+
+                if mode.is_api() {
+                    poll_stripe_events_periodically(state.clone());
+                    fetch_extensions_from_blob_store_periodically(state.clone());
+
+                    app = app
                         .merge(collab::api::events::router())
-                        .layer(Extension(state.clone())),
-                )
-                .layer(
-                    TraceLayer::new_for_http()
-                        .make_span_with(|request: &Request<_>| {
-                            let matched_path = request
-                                .extensions()
-                                .get::<MatchedPath>()
-                                .map(MatchedPath::as_str);
-
-                            tracing::info_span!(
-                                "http_request",
-                                method = ?request.method(),
-                                matched_path,
-                            )
-                        })
-                        .on_response(
-                            |response: &Response<_>, latency: Duration, _: &tracing::Span| {
-                                let duration_ms = latency.as_micros() as f64 / 1000.;
-                                tracing::info!(
-                                    duration_ms,
-                                    status = response.status().as_u16(),
-                                    "finished processing request"
-                                );
-                            },
-                        ),
-                );
+                        .merge(collab::api::extensions::router())
+                }
+
+                app = app.layer(Extension(state.clone()));
+            }
+
+            app = app.layer(
+                TraceLayer::new_for_http()
+                    .make_span_with(|request: &Request<_>| {
+                        let matched_path = request
+                            .extensions()
+                            .get::<MatchedPath>()
+                            .map(MatchedPath::as_str);
+
+                        tracing::info_span!(
+                            "http_request",
+                            method = ?request.method(),
+                            matched_path,
+                        )
+                    })
+                    .on_response(
+                        |response: &Response<_>, latency: Duration, _: &tracing::Span| {
+                            let duration_ms = latency.as_micros() as f64 / 1000.;
+                            tracing::info!(
+                                duration_ms,
+                                status = response.status().as_u16(),
+                                "finished processing request"
+                            );
+                        },
+                    ),
+            );
 
             #[cfg(unix)]
             let signal = async move {
@@ -174,8 +185,8 @@ async fn main() -> Result<()> {
                     signal.await;
                     tracing::info!("Received interrupt signal");
 
-                    if let Some(rpc_server) = rpc_server {
-                        rpc_server.teardown();
+                    if let Some(on_shutdown) = on_shutdown {
+                        on_shutdown();
                     }
                 })
                 .await
@@ -183,7 +194,7 @@ async fn main() -> Result<()> {
         }
         _ => {
             Err(anyhow!(
-                "usage: collab <version | migrate | seed | serve [api|collab]>"
+                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
             ))?;
         }
     }
@@ -222,12 +233,23 @@ async fn run_migrations(config: &Config) -> Result<()> {
     return Ok(());
 }
 
-async fn handle_root() -> String {
-    format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
+async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
+    format!(
+        "collab {mode:?} v{VERSION} ({})",
+        REVISION.unwrap_or("unknown")
+    )
 }
 
-async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
-    state.db.get_all_users(0, 1).await?;
+async fn handle_liveness_probe(
+    app_state: Option<Extension<Arc<AppState>>>,
+    llm_state: Option<Extension<Arc<LlmState>>>,
+) -> Result<String> {
+    if let Some(state) = app_state {
+        state.db.get_all_users(0, 1).await?;
+    }
+
+    if let Some(_llm_state) = llm_state {}
+
     Ok("ok".to_string())
 }