main.rs

  1use anyhow::anyhow;
  2use axum::headers::HeaderMapExt;
  3use axum::{
  4    Extension, Router,
  5    extract::MatchedPath,
  6    http::{Request, Response},
  7    routing::get,
  8};
  9
 10use collab::api::CloudflareIpCountryHeader;
 11use collab::api::billing::sync_llm_usage_with_stripe_periodically;
 12use collab::llm::{db::LlmDatabase, log_usage_periodically};
 13use collab::migrations::run_database_migrations;
 14use collab::user_backfiller::spawn_user_backfiller;
 15use collab::{
 16    AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
 17    env, executor::Executor, rpc::ResultExt,
 18};
 19use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState};
 20use db::Database;
 21use std::{
 22    env::args,
 23    net::{SocketAddr, TcpListener},
 24    path::Path,
 25    sync::Arc,
 26    time::Duration,
 27};
 28#[cfg(unix)]
 29use tokio::signal::unix::SignalKind;
 30use tower_http::trace::TraceLayer;
 31use tracing_subscriber::{
 32    Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt,
 33};
 34use util::{ResultExt as _, maybe};
 35
 36const VERSION: &str = env!("CARGO_PKG_VERSION");
 37const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 38
 39#[tokio::main]
 40async fn main() -> Result<()> {
 41    if let Err(error) = env::load_dotenv() {
 42        eprintln!(
 43            "error loading .env.toml (this is expected in production): {}",
 44            error
 45        );
 46    }
 47
 48    let mut args = args().skip(1);
 49    match args.next().as_deref() {
 50        Some("version") => {
 51            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 52        }
 53        Some("migrate") => {
 54            let config = envy::from_env::<Config>().expect("error loading config");
 55            setup_app_database(&config).await?;
 56        }
 57        Some("seed") => {
 58            let config = envy::from_env::<Config>().expect("error loading config");
 59            let db_options = db::ConnectOptions::new(config.database_url.clone());
 60
 61            let mut db = Database::new(db_options, Executor::Production).await?;
 62            db.initialize_notification_kinds().await?;
 63
 64            collab::seed::seed(&config, &db, false).await?;
 65
 66            if let Some(llm_database_url) = config.llm_database_url.clone() {
 67                let db_options = db::ConnectOptions::new(llm_database_url);
 68                let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
 69                db.initialize().await?;
 70                collab::llm::db::seed_database(&config, &mut db, true).await?;
 71            }
 72        }
 73        Some("serve") => {
 74            let mode = match args.next().as_deref() {
 75                Some("collab") => ServiceMode::Collab,
 76                Some("api") => ServiceMode::Api,
 77                Some("llm") => ServiceMode::Llm,
 78                Some("all") => ServiceMode::All,
 79                _ => {
 80                    return Err(anyhow!(
 81                        "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
 82                    ))?;
 83                }
 84            };
 85
 86            let config = envy::from_env::<Config>().expect("error loading config");
 87            init_tracing(&config);
 88            init_panic_hook();
 89
 90            let mut app = Router::new()
 91                .route("/", get(handle_root))
 92                .route("/healthz", get(handle_liveness_probe))
 93                .layer(Extension(mode));
 94
 95            let listener = TcpListener::bind(format!("0.0.0.0:{}", config.http_port))
 96                .expect("failed to bind TCP listener");
 97
 98            let mut on_shutdown = None;
 99
100            if mode.is_llm() {
101                setup_llm_database(&config).await?;
102
103                let state = LlmState::new(config.clone(), Executor::Production).await?;
104
105                log_usage_periodically(state.clone());
106
107                app = app
108                    .merge(collab::llm::routes())
109                    .layer(Extension(state.clone()));
110            }
111
112            if mode.is_collab() || mode.is_api() {
113                setup_app_database(&config).await?;
114
115                let state = AppState::new(config, Executor::Production).await?;
116
117                if let Some(stripe_billing) = state.stripe_billing.clone() {
118                    let executor = state.executor.clone();
119                    executor.spawn_detached(async move {
120                        stripe_billing.initialize().await.trace_err();
121                    });
122                }
123
124                if mode.is_collab() {
125                    state.db.purge_old_embeddings().await.trace_err();
126                    RateLimiter::save_periodically(
127                        state.rate_limiter.clone(),
128                        state.executor.clone(),
129                    );
130
131                    let epoch = state
132                        .db
133                        .create_server(&state.config.zed_environment)
134                        .await?;
135                    let rpc_server = collab::rpc::Server::new(epoch, state.clone());
136                    rpc_server.start().await?;
137
138                    poll_stripe_events_periodically(state.clone(), rpc_server.clone());
139
140                    app = app
141                        .merge(collab::api::routes(rpc_server.clone()))
142                        .merge(collab::rpc::routes(rpc_server.clone()));
143
144                    on_shutdown = Some(Box::new(move || rpc_server.teardown()));
145                }
146
147                if mode.is_api() {
148                    fetch_extensions_from_blob_store_periodically(state.clone());
149                    spawn_user_backfiller(state.clone());
150
151                    let llm_db = maybe!(async {
152                        let database_url = state
153                            .config
154                            .llm_database_url
155                            .as_ref()
156                            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
157                        let max_connections = state
158                            .config
159                            .llm_database_max_connections
160                            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
161
162                        let mut db_options = db::ConnectOptions::new(database_url);
163                        db_options.max_connections(max_connections);
164                        LlmDatabase::new(db_options, state.executor.clone()).await
165                    })
166                    .await
167                    .trace_err();
168
169                    if let Some(mut llm_db) = llm_db {
170                        llm_db.initialize().await?;
171                        sync_llm_usage_with_stripe_periodically(state.clone());
172                    }
173
174                    app = app
175                        .merge(collab::api::events::router())
176                        .merge(collab::api::extensions::router())
177                }
178
179                app = app.layer(Extension(state.clone()));
180            }
181
182            app = app.layer(
183                TraceLayer::new_for_http()
184                    .make_span_with(|request: &Request<_>| {
185                        let matched_path = request
186                            .extensions()
187                            .get::<MatchedPath>()
188                            .map(MatchedPath::as_str);
189
190                        let geoip_country_code = request
191                            .headers()
192                            .typed_get::<CloudflareIpCountryHeader>()
193                            .map(|header| header.to_string());
194
195                        tracing::info_span!(
196                            "http_request",
197                            method = ?request.method(),
198                            matched_path,
199                            geoip_country_code,
200                            user_id = tracing::field::Empty,
201                            login = tracing::field::Empty,
202                            authn.jti = tracing::field::Empty,
203                            is_staff = tracing::field::Empty
204                        )
205                    })
206                    .on_response(
207                        |response: &Response<_>, latency: Duration, _: &tracing::Span| {
208                            let duration_ms = latency.as_micros() as f64 / 1000.;
209                            tracing::info!(
210                                duration_ms,
211                                status = response.status().as_u16(),
212                                "finished processing request"
213                            );
214                        },
215                    ),
216            );
217
218            #[cfg(unix)]
219            let signal = async move {
220                let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
221                    .expect("failed to listen for interrupt signal");
222                let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
223                    .expect("failed to listen for interrupt signal");
224                let sigterm = sigterm.recv();
225                let sigint = sigint.recv();
226                futures::pin_mut!(sigterm, sigint);
227                futures::future::select(sigterm, sigint).await;
228            };
229
230            #[cfg(windows)]
231            let signal = async move {
232                // todo(windows):
233                // `ctrl_close` does not work well, because tokio's signal handler always returns soon,
234                // but system terminates the application soon after returning CTRL+CLOSE handler.
235                // So we should implement blocking handler to treat CTRL+CLOSE signal.
236                let mut ctrl_break = tokio::signal::windows::ctrl_break()
237                    .expect("failed to listen for interrupt signal");
238                let mut ctrl_c = tokio::signal::windows::ctrl_c()
239                    .expect("failed to listen for interrupt signal");
240                let ctrl_break = ctrl_break.recv();
241                let ctrl_c = ctrl_c.recv();
242                futures::pin_mut!(ctrl_break, ctrl_c);
243                futures::future::select(ctrl_break, ctrl_c).await;
244            };
245
246            axum::Server::from_tcp(listener)
247                .map_err(|e| anyhow!(e))?
248                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
249                .with_graceful_shutdown(async move {
250                    signal.await;
251                    tracing::info!("Received interrupt signal");
252
253                    if let Some(on_shutdown) = on_shutdown {
254                        on_shutdown();
255                    }
256                })
257                .await
258                .map_err(|e| anyhow!(e))?;
259        }
260        _ => {
261            Err(anyhow!(
262                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
263            ))?;
264        }
265    }
266    Ok(())
267}
268
269async fn setup_app_database(config: &Config) -> Result<()> {
270    let db_options = db::ConnectOptions::new(config.database_url.clone());
271    let mut db = Database::new(db_options, Executor::Production).await?;
272
273    let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| {
274        #[cfg(feature = "sqlite")]
275        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
276        #[cfg(not(feature = "sqlite"))]
277        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
278
279        Path::new(default_migrations)
280    });
281
282    let migrations = run_database_migrations(db.options(), migrations_path).await?;
283    for (migration, duration) in migrations {
284        log::info!(
285            "Migrated {} {} {:?}",
286            migration.version,
287            migration.description,
288            duration
289        );
290    }
291
292    db.initialize_notification_kinds().await?;
293
294    if config.seed_path.is_some() {
295        collab::seed::seed(config, &db, false).await?;
296    }
297
298    Ok(())
299}
300
301async fn setup_llm_database(config: &Config) -> Result<()> {
302    let database_url = config
303        .llm_database_url
304        .as_ref()
305        .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
306
307    let db_options = db::ConnectOptions::new(database_url.clone());
308    let db = LlmDatabase::new(db_options, Executor::Production).await?;
309
310    let migrations_path = config
311        .llm_database_migrations_path
312        .as_deref()
313        .unwrap_or_else(|| {
314            #[cfg(feature = "sqlite")]
315            let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
316            #[cfg(not(feature = "sqlite"))]
317            let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
318
319            Path::new(default_migrations)
320        });
321
322    let migrations = run_database_migrations(db.options(), migrations_path).await?;
323    for (migration, duration) in migrations {
324        log::info!(
325            "Migrated {} {} {:?}",
326            migration.version,
327            migration.description,
328            duration
329        );
330    }
331
332    Ok(())
333}
334
335async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
336    format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
337}
338
339async fn handle_liveness_probe(
340    app_state: Option<Extension<Arc<AppState>>>,
341    llm_state: Option<Extension<Arc<LlmState>>>,
342) -> Result<String> {
343    if let Some(state) = app_state {
344        state.db.get_all_users(0, 1).await?;
345    }
346
347    if let Some(llm_state) = llm_state {
348        llm_state.db.list_providers().await?;
349    }
350
351    Ok("ok".to_string())
352}
353
354pub fn init_tracing(config: &Config) -> Option<()> {
355    use std::str::FromStr;
356    use tracing_subscriber::layer::SubscriberExt;
357
358    let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
359
360    tracing_subscriber::registry()
361        .with(if config.log_json.unwrap_or(false) {
362            Box::new(
363                tracing_subscriber::fmt::layer()
364                    .fmt_fields(JsonFields::default())
365                    .event_format(
366                        tracing_subscriber::fmt::format()
367                            .json()
368                            .flatten_event(true)
369                            .with_span_list(false),
370                    )
371                    .with_filter(filter),
372            ) as Box<dyn Layer<_> + Send + Sync>
373        } else {
374            Box::new(
375                tracing_subscriber::fmt::layer()
376                    .event_format(tracing_subscriber::fmt::format().pretty())
377                    .with_filter(filter),
378            )
379        })
380        .init();
381
382    None
383}
384
385fn init_panic_hook() {
386    std::panic::set_hook(Box::new(move |panic_info| {
387        let panic_message = match panic_info.payload().downcast_ref::<&'static str>() {
388            Some(message) => *message,
389            None => match panic_info.payload().downcast_ref::<String>() {
390                Some(message) => message.as_str(),
391                None => "Box<Any>",
392            },
393        };
394        let backtrace = std::backtrace::Backtrace::force_capture();
395        let location = panic_info
396            .location()
397            .map(|loc| format!("{}:{}", loc.file(), loc.line()));
398        tracing::error!(panic = true, ?location, %panic_message, %backtrace, "Server Panic");
399    }));
400}