Procfile 🔗
@@ -1,3 +1,3 @@
-collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve
+collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
livekit: livekit-server --dev
blob_store: ./script/run-local-minio
Max Brunsfeld , Marshall , and Jason created
This is just a refactor that we're landing ahead of any functional
changes to make sure we haven't broken anything.
Release Notes:
- N/A
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Jason <jason@zed.dev>
Procfile | 2
crates/collab/src/api.rs | 9 -
crates/collab/src/lib.rs | 26 +++++
crates/collab/src/llm.rs | 16 +++
crates/collab/src/main.rs | 178 +++++++++++++++++++++++-----------------
5 files changed, 144 insertions(+), 87 deletions(-)
@@ -1,3 +1,3 @@
-collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve
+collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
livekit: livekit-server --dev
blob_store: ./script/run-local-minio
@@ -61,7 +61,7 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
}
}
-pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
+pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new()
.route("/user", get(get_authenticated_user))
.route("/users/:id/access_tokens", post(create_access_token))
@@ -70,7 +70,6 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
.merge(contributors::router())
.layer(
ServiceBuilder::new()
- .layer(Extension(state))
.layer(Extension(rpc_server))
.layer(middleware::from_fn(validate_api_token)),
)
@@ -152,12 +151,8 @@ struct CreateUserParams {
}
async fn get_rpc_server_snapshot(
- Extension(rpc_server): Extension<Option<Arc<rpc::Server>>>,
+ Extension(rpc_server): Extension<Arc<rpc::Server>>,
) -> Result<ErasedJson> {
- let Some(rpc_server) = rpc_server else {
- return Err(Error::Internal(anyhow!("rpc server is not available")));
- };
-
Ok(ErasedJson::pretty(rpc_server.snapshot().await))
}
@@ -3,6 +3,7 @@ pub mod auth;
pub mod db;
pub mod env;
pub mod executor;
+pub mod llm;
mod rate_limiter;
pub mod rpc;
pub mod seed;
@@ -124,7 +125,7 @@ impl std::fmt::Display for Error {
impl std::error::Error for Error {}
-#[derive(Deserialize)]
+#[derive(Clone, Deserialize)]
pub struct Config {
pub http_port: u16,
pub database_url: String,
@@ -176,6 +177,29 @@ impl Config {
}
}
+/// The service mode that collab should run in.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum ServiceMode {
+ Api,
+ Collab,
+ Llm,
+ All,
+}
+
+impl ServiceMode {
+ pub fn is_collab(&self) -> bool {
+ matches!(self, Self::Collab | Self::All)
+ }
+
+ pub fn is_api(&self) -> bool {
+ matches!(self, Self::Api | Self::All)
+ }
+
+ pub fn is_llm(&self) -> bool {
+ matches!(self, Self::Llm | Self::All)
+ }
+}
+
pub struct AppState {
pub db: Arc<Database>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
@@ -0,0 +1,16 @@
+use std::sync::Arc;
+
+use crate::{executor::Executor, Config, Result};
+
+pub struct LlmState {
+ pub config: Config,
+ pub executor: Executor,
+}
+
+impl LlmState {
+ pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
+ let this = Self { config, executor };
+
+ Ok(Arc::new(this))
+ }
+}
@@ -5,7 +5,7 @@ use axum::{
routing::get,
Extension, Router,
};
-use collab::api::billing::poll_stripe_events_periodically;
+use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
use collab::{
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
rpc::ResultExt, AppState, Config, RateLimiter, Result,
@@ -56,88 +56,99 @@ async fn main() -> Result<()> {
collab::seed::seed(&config, &db, true).await?;
}
Some("serve") => {
- let (is_api, is_collab) = if let Some(next) = args.next() {
- (next == "api", next == "collab")
- } else {
- (true, true)
+ let mode = match args.next().as_deref() {
+ Some("collab") => ServiceMode::Collab,
+ Some("api") => ServiceMode::Api,
+ Some("llm") => ServiceMode::Llm,
+ Some("all") => ServiceMode::All,
+ _ => {
+ return Err(anyhow!(
+ "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
+ ))?;
+ }
};
- if !is_api && !is_collab {
- Err(anyhow!(
- "usage: collab <version | migrate | seed | serve [api|collab]>"
- ))?;
- }
let config = envy::from_env::<Config>().expect("error loading config");
init_tracing(&config);
+ let mut app = Router::new()
+ .route("/", get(handle_root))
+ .route("/healthz", get(handle_liveness_probe))
+ .layer(Extension(mode));
- run_migrations(&config).await?;
-
- let state = AppState::new(config, Executor::Production).await?;
-
- let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
+ let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
.expect("failed to bind TCP listener");
- let rpc_server = if is_collab {
- let epoch = state
- .db
- .create_server(&state.config.zed_environment)
- .await?;
- let rpc_server = collab::rpc::Server::new(epoch, state.clone());
- rpc_server.start().await?;
-
- Some(rpc_server)
- } else {
- None
- };
+ let mut on_shutdown = None;
- if is_collab {
- state.db.purge_old_embeddings().await.trace_err();
- RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
- }
+ if mode.is_llm() {
+ let state = LlmState::new(config.clone(), Executor::Production).await?;
- if is_api {
- poll_stripe_events_periodically(state.clone());
- fetch_extensions_from_blob_store_periodically(state.clone());
+ app = app.layer(Extension(state.clone()));
}
- let mut app = collab::api::routes(rpc_server.clone(), state.clone());
- if let Some(rpc_server) = rpc_server.clone() {
- app = app.merge(collab::rpc::routes(rpc_server))
- }
- app = app
- .merge(
- Router::new()
- .route("/", get(handle_root))
- .route("/healthz", get(handle_liveness_probe))
- .merge(collab::api::extensions::router())
+ if mode.is_collab() || mode.is_api() {
+ run_migrations(&config).await?;
+
+ let state = AppState::new(config, Executor::Production).await?;
+
+ if mode.is_collab() {
+ state.db.purge_old_embeddings().await.trace_err();
+ RateLimiter::save_periodically(
+ state.rate_limiter.clone(),
+ state.executor.clone(),
+ );
+
+ let epoch = state
+ .db
+ .create_server(&state.config.zed_environment)
+ .await?;
+ let rpc_server = collab::rpc::Server::new(epoch, state.clone());
+ rpc_server.start().await?;
+
+ app = app
+ .merge(collab::api::routes(rpc_server.clone()))
+ .merge(collab::rpc::routes(rpc_server.clone()));
+
+ on_shutdown = Some(Box::new(move || rpc_server.teardown()));
+ }
+
+ if mode.is_api() {
+ poll_stripe_events_periodically(state.clone());
+ fetch_extensions_from_blob_store_periodically(state.clone());
+
+ app = app
.merge(collab::api::events::router())
- .layer(Extension(state.clone())),
- )
- .layer(
- TraceLayer::new_for_http()
- .make_span_with(|request: &Request<_>| {
- let matched_path = request
- .extensions()
- .get::<MatchedPath>()
- .map(MatchedPath::as_str);
-
- tracing::info_span!(
- "http_request",
- method = ?request.method(),
- matched_path,
- )
- })
- .on_response(
- |response: &Response<_>, latency: Duration, _: &tracing::Span| {
- let duration_ms = latency.as_micros() as f64 / 1000.;
- tracing::info!(
- duration_ms,
- status = response.status().as_u16(),
- "finished processing request"
- );
- },
- ),
- );
+ .merge(collab::api::extensions::router())
+ }
+
+ app = app.layer(Extension(state.clone()));
+ }
+
+ app = app.layer(
+ TraceLayer::new_for_http()
+ .make_span_with(|request: &Request<_>| {
+ let matched_path = request
+ .extensions()
+ .get::<MatchedPath>()
+ .map(MatchedPath::as_str);
+
+ tracing::info_span!(
+ "http_request",
+ method = ?request.method(),
+ matched_path,
+ )
+ })
+ .on_response(
+ |response: &Response<_>, latency: Duration, _: &tracing::Span| {
+ let duration_ms = latency.as_micros() as f64 / 1000.;
+ tracing::info!(
+ duration_ms,
+ status = response.status().as_u16(),
+ "finished processing request"
+ );
+ },
+ ),
+ );
#[cfg(unix)]
let signal = async move {
@@ -174,8 +185,8 @@ async fn main() -> Result<()> {
signal.await;
tracing::info!("Received interrupt signal");
- if let Some(rpc_server) = rpc_server {
- rpc_server.teardown();
+ if let Some(on_shutdown) = on_shutdown {
+ on_shutdown();
}
})
.await
@@ -183,7 +194,7 @@ async fn main() -> Result<()> {
}
_ => {
Err(anyhow!(
- "usage: collab <version | migrate | seed | serve [api|collab]>"
+ "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
))?;
}
}
@@ -222,12 +233,23 @@ async fn run_migrations(config: &Config) -> Result<()> {
return Ok(());
}
-async fn handle_root() -> String {
- format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
+async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
+ format!(
+ "collab {mode:?} v{VERSION} ({})",
+ REVISION.unwrap_or("unknown")
+ )
}
-async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
- state.db.get_all_users(0, 1).await?;
+async fn handle_liveness_probe(
+ app_state: Option<Extension<Arc<AppState>>>,
+ llm_state: Option<Extension<Arc<LlmState>>>,
+) -> Result<String> {
+ if let Some(state) = app_state {
+ state.db.get_all_users(0, 1).await?;
+ }
+
+ if let Some(_llm_state) = llm_state {}
+
Ok("ok".to_string())
}