main.rs

  1use anyhow::anyhow;
  2use axum::{
  3    extract::MatchedPath,
  4    http::{Request, Response},
  5    routing::get,
  6    Extension, Router,
  7};
  8use collab::llm::{db::LlmDatabase, log_usage_periodically};
  9use collab::migrations::run_database_migrations;
 10use collab::user_backfiller::spawn_user_backfiller;
 11use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
 12use collab::{
 13    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
 14    rpc::ResultExt, AppState, Config, RateLimiter, Result,
 15};
 16use db::Database;
 17use std::{
 18    env::args,
 19    net::{SocketAddr, TcpListener},
 20    path::Path,
 21    sync::Arc,
 22    time::Duration,
 23};
 24#[cfg(unix)]
 25use tokio::signal::unix::SignalKind;
 26use tower_http::trace::TraceLayer;
 27use tracing_subscriber::{
 28    filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
 29};
 30use util::ResultExt as _;
 31
 32const VERSION: &str = env!("CARGO_PKG_VERSION");
 33const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 34
 35#[tokio::main]
 36async fn main() -> Result<()> {
 37    if let Err(error) = env::load_dotenv() {
 38        eprintln!(
 39            "error loading .env.toml (this is expected in production): {}",
 40            error
 41        );
 42    }
 43
 44    let mut args = args().skip(1);
 45    match args.next().as_deref() {
 46        Some("version") => {
 47            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 48        }
 49        Some("migrate") => {
 50            let config = envy::from_env::<Config>().expect("error loading config");
 51            setup_app_database(&config).await?;
 52        }
 53        Some("seed") => {
 54            let config = envy::from_env::<Config>().expect("error loading config");
 55            let db_options = db::ConnectOptions::new(config.database_url.clone());
 56
 57            let mut db = Database::new(db_options, Executor::Production).await?;
 58            db.initialize_notification_kinds().await?;
 59
 60            collab::seed::seed(&config, &db, false).await?;
 61
 62            if let Some(llm_database_url) = config.llm_database_url.clone() {
 63                let db_options = db::ConnectOptions::new(llm_database_url);
 64                let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
 65                db.initialize().await?;
 66                collab::llm::db::seed_database(&config, &mut db, true).await?;
 67            }
 68        }
 69        Some("serve") => {
 70            let mode = match args.next().as_deref() {
 71                Some("collab") => ServiceMode::Collab,
 72                Some("api") => ServiceMode::Api,
 73                Some("llm") => ServiceMode::Llm,
 74                Some("all") => ServiceMode::All,
 75                _ => {
 76                    return Err(anyhow!(
 77                        "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
 78                    ))?;
 79                }
 80            };
 81
 82            let config = envy::from_env::<Config>().expect("error loading config");
 83            init_tracing(&config);
 84            let mut app = Router::new()
 85                .route("/", get(handle_root))
 86                .route("/healthz", get(handle_liveness_probe))
 87                .layer(Extension(mode));
 88
 89            let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
 90                .expect("failed to bind TCP listener");
 91
 92            let mut on_shutdown = None;
 93
 94            if mode.is_llm() {
 95                setup_llm_database(&config).await?;
 96
 97                let state = LlmState::new(config.clone(), Executor::Production).await?;
 98
 99                log_usage_periodically(state.clone());
100
101                app = app
102                    .merge(collab::llm::routes())
103                    .layer(Extension(state.clone()));
104            }
105
106            if mode.is_collab() || mode.is_api() {
107                setup_app_database(&config).await?;
108
109                let state = AppState::new(config, Executor::Production).await?;
110
111                if mode.is_collab() {
112                    state.db.purge_old_embeddings().await.trace_err();
113                    RateLimiter::save_periodically(
114                        state.rate_limiter.clone(),
115                        state.executor.clone(),
116                    );
117
118                    let epoch = state
119                        .db
120                        .create_server(&state.config.zed_environment)
121                        .await?;
122                    let rpc_server = collab::rpc::Server::new(epoch, state.clone());
123                    rpc_server.start().await?;
124
125                    app = app
126                        .merge(collab::api::routes(rpc_server.clone()))
127                        .merge(collab::rpc::routes(rpc_server.clone()));
128
129                    on_shutdown = Some(Box::new(move || rpc_server.teardown()));
130                }
131
132                if mode.is_api() {
133                    poll_stripe_events_periodically(state.clone());
134                    fetch_extensions_from_blob_store_periodically(state.clone());
135                    spawn_user_backfiller(state.clone());
136
137                    app = app
138                        .merge(collab::api::events::router())
139                        .merge(collab::api::extensions::router())
140                }
141
142                app = app.layer(Extension(state.clone()));
143            }
144
145            app = app.layer(
146                TraceLayer::new_for_http()
147                    .make_span_with(|request: &Request<_>| {
148                        let matched_path = request
149                            .extensions()
150                            .get::<MatchedPath>()
151                            .map(MatchedPath::as_str);
152
153                        tracing::info_span!(
154                            "http_request",
155                            method = ?request.method(),
156                            matched_path,
157                            user_id = tracing::field::Empty,
158                            login = tracing::field::Empty,
159                            authn.jti = tracing::field::Empty,
160                            is_staff = tracing::field::Empty
161                        )
162                    })
163                    .on_response(
164                        |response: &Response<_>, latency: Duration, _: &tracing::Span| {
165                            let duration_ms = latency.as_micros() as f64 / 1000.;
166                            tracing::info!(
167                                duration_ms,
168                                status = response.status().as_u16(),
169                                "finished processing request"
170                            );
171                        },
172                    ),
173            );
174
175            #[cfg(unix)]
176            let signal = async move {
177                let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
178                    .expect("failed to listen for interrupt signal");
179                let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
180                    .expect("failed to listen for interrupt signal");
181                let sigterm = sigterm.recv();
182                let sigint = sigint.recv();
183                futures::pin_mut!(sigterm, sigint);
184                futures::future::select(sigterm, sigint).await;
185            };
186
187            #[cfg(windows)]
188            let signal = async move {
189                // todo(windows):
190                // `ctrl_close` does not work well, because tokio's signal handler always returns soon,
191                // but system terminates the application soon after returning CTRL+CLOSE handler.
192                // So we should implement blocking handler to treat CTRL+CLOSE signal.
193                let mut ctrl_break = tokio::signal::windows::ctrl_break()
194                    .expect("failed to listen for interrupt signal");
195                let mut ctrl_c = tokio::signal::windows::ctrl_c()
196                    .expect("failed to listen for interrupt signal");
197                let ctrl_break = ctrl_break.recv();
198                let ctrl_c = ctrl_c.recv();
199                futures::pin_mut!(ctrl_break, ctrl_c);
200                futures::future::select(ctrl_break, ctrl_c).await;
201            };
202
203            axum::Server::from_tcp(listener)
204                .map_err(|e| anyhow!(e))?
205                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
206                .with_graceful_shutdown(async move {
207                    signal.await;
208                    tracing::info!("Received interrupt signal");
209
210                    if let Some(on_shutdown) = on_shutdown {
211                        on_shutdown();
212                    }
213                })
214                .await
215                .map_err(|e| anyhow!(e))?;
216        }
217        _ => {
218            Err(anyhow!(
219                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
220            ))?;
221        }
222    }
223    Ok(())
224}
225
226async fn setup_app_database(config: &Config) -> Result<()> {
227    let db_options = db::ConnectOptions::new(config.database_url.clone());
228    let mut db = Database::new(db_options, Executor::Production).await?;
229
230    let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| {
231        #[cfg(feature = "sqlite")]
232        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
233        #[cfg(not(feature = "sqlite"))]
234        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
235
236        Path::new(default_migrations)
237    });
238
239    let migrations = run_database_migrations(db.options(), migrations_path).await?;
240    for (migration, duration) in migrations {
241        log::info!(
242            "Migrated {} {} {:?}",
243            migration.version,
244            migration.description,
245            duration
246        );
247    }
248
249    db.initialize_notification_kinds().await?;
250
251    if config.seed_path.is_some() {
252        collab::seed::seed(&config, &db, false).await?;
253    }
254
255    Ok(())
256}
257
258async fn setup_llm_database(config: &Config) -> Result<()> {
259    let database_url = config
260        .llm_database_url
261        .as_ref()
262        .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
263
264    let db_options = db::ConnectOptions::new(database_url.clone());
265    let db = LlmDatabase::new(db_options, Executor::Production).await?;
266
267    let migrations_path = config
268        .llm_database_migrations_path
269        .as_deref()
270        .unwrap_or_else(|| {
271            #[cfg(feature = "sqlite")]
272            let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
273            #[cfg(not(feature = "sqlite"))]
274            let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
275
276            Path::new(default_migrations)
277        });
278
279    let migrations = run_database_migrations(db.options(), migrations_path).await?;
280    for (migration, duration) in migrations {
281        log::info!(
282            "Migrated {} {} {:?}",
283            migration.version,
284            migration.description,
285            duration
286        );
287    }
288
289    Ok(())
290}
291
292async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
293    format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
294}
295
296async fn handle_liveness_probe(
297    app_state: Option<Extension<Arc<AppState>>>,
298    llm_state: Option<Extension<Arc<LlmState>>>,
299) -> Result<String> {
300    if let Some(state) = app_state {
301        state.db.get_all_users(0, 1).await?;
302    }
303
304    if let Some(llm_state) = llm_state {
305        llm_state.db.list_providers().await?;
306    }
307
308    Ok("ok".to_string())
309}
310
311pub fn init_tracing(config: &Config) -> Option<()> {
312    use std::str::FromStr;
313    use tracing_subscriber::layer::SubscriberExt;
314
315    let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
316
317    tracing_subscriber::registry()
318        .with(if config.log_json.unwrap_or(false) {
319            Box::new(
320                tracing_subscriber::fmt::layer()
321                    .fmt_fields(JsonFields::default())
322                    .event_format(
323                        tracing_subscriber::fmt::format()
324                            .json()
325                            .flatten_event(true)
326                            .with_span_list(false),
327                    )
328                    .with_filter(filter),
329            ) as Box<dyn Layer<_> + Send + Sync>
330        } else {
331            Box::new(
332                tracing_subscriber::fmt::layer()
333                    .event_format(tracing_subscriber::fmt::format().pretty())
334                    .with_filter(filter),
335            )
336        })
337        .init();
338
339    None
340}