main.rs

  1use anyhow::anyhow;
  2use axum::headers::HeaderMapExt;
  3use axum::{
  4    Extension, Router,
  5    extract::MatchedPath,
  6    http::{Request, Response},
  7    routing::get,
  8};
  9
 10use collab::api::CloudflareIpCountryHeader;
 11use collab::{
 12    AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
 13    executor::Executor,
 14};
 15use collab::{REVISION, ServiceMode, VERSION};
 16use db::Database;
 17use std::{
 18    env::args,
 19    net::{SocketAddr, TcpListener},
 20    sync::Arc,
 21    time::Duration,
 22};
 23#[cfg(unix)]
 24use tokio::signal::unix::SignalKind;
 25use tower_http::trace::TraceLayer;
 26use tracing_subscriber::{
 27    Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt,
 28};
 29use util::ResultExt as _;
 30
 31#[expect(clippy::result_large_err)]
 32#[tokio::main]
 33async fn main() -> Result<()> {
 34    if let Err(error) = env::load_dotenv() {
 35        eprintln!(
 36            "error loading .env.toml (this is expected in production): {}",
 37            error
 38        );
 39    }
 40
 41    let mut args = args().skip(1);
 42    match args.next().as_deref() {
 43        Some("version") => {
 44            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 45        }
 46        Some("seed") => {
 47            let config = envy::from_env::<Config>().expect("error loading config");
 48            let db_options = db::ConnectOptions::new(config.database_url.clone());
 49
 50            let mut db = Database::new(db_options).await?;
 51            db.initialize_notification_kinds().await?;
 52
 53            collab::seed::seed(&config, &db, false).await?;
 54        }
 55        Some("serve") => {
 56            let mode = match args.next().as_deref() {
 57                Some("collab") => ServiceMode::Collab,
 58                Some("api") => ServiceMode::Api,
 59                Some("all") => ServiceMode::All,
 60                _ => {
 61                    return Err(anyhow!(
 62                        "usage: collab <version | seed | serve <api|collab|all>>"
 63                    ))?;
 64                }
 65            };
 66
 67            let config = envy::from_env::<Config>().expect("error loading config");
 68            init_tracing(&config);
 69            init_panic_hook();
 70
 71            let mut app = Router::new()
 72                .route("/", get(handle_root))
 73                .route("/healthz", get(handle_liveness_probe))
 74                .layer(Extension(mode));
 75
 76            let listener = TcpListener::bind(format!("0.0.0.0:{}", config.http_port))
 77                .expect("failed to bind TCP listener");
 78
 79            let mut on_shutdown = None;
 80
 81            if mode.is_collab() || mode.is_api() {
 82                setup_app_database(&config).await?;
 83
 84                let state = AppState::new(config, Executor::Production).await?;
 85
 86                if mode.is_collab() {
 87                    let epoch = state
 88                        .db
 89                        .create_server(&state.config.zed_environment)
 90                        .await?;
 91                    let rpc_server = collab::rpc::Server::new(epoch, state.clone());
 92                    rpc_server.start().await?;
 93
 94                    app = app.merge(collab::rpc::routes(rpc_server.clone()));
 95
 96                    on_shutdown = Some(Box::new(move || rpc_server.teardown()));
 97                }
 98
 99                if mode.is_api() {
100                    fetch_extensions_from_blob_store_periodically(state.clone());
101
102                    app = app
103                        .merge(collab::api::events::router())
104                        .merge(collab::api::extensions::router())
105                }
106
107                app = app.layer(Extension(state.clone()));
108            }
109
110            app = app.layer(
111                TraceLayer::new_for_http()
112                    .make_span_with(|request: &Request<_>| {
113                        let matched_path = request
114                            .extensions()
115                            .get::<MatchedPath>()
116                            .map(MatchedPath::as_str);
117
118                        let geoip_country_code = request
119                            .headers()
120                            .typed_get::<CloudflareIpCountryHeader>()
121                            .map(|header| header.to_string());
122
123                        tracing::info_span!(
124                            "http_request",
125                            method = ?request.method(),
126                            matched_path,
127                            geoip_country_code,
128                            user_id = tracing::field::Empty,
129                            login = tracing::field::Empty,
130                            authn.jti = tracing::field::Empty,
131                            is_staff = tracing::field::Empty
132                        )
133                    })
134                    .on_response(
135                        |response: &Response<_>, latency: Duration, _: &tracing::Span| {
136                            let duration_ms = latency.as_micros() as f64 / 1000.;
137                            tracing::info!(
138                                duration_ms,
139                                status = response.status().as_u16(),
140                                "finished processing request"
141                            );
142                        },
143                    ),
144            );
145
146            #[cfg(unix)]
147            let signal = async move {
148                let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
149                    .expect("failed to listen for interrupt signal");
150                let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
151                    .expect("failed to listen for interrupt signal");
152                let sigterm = sigterm.recv();
153                let sigint = sigint.recv();
154                futures::pin_mut!(sigterm, sigint);
155                futures::future::select(sigterm, sigint).await;
156            };
157
158            #[cfg(windows)]
159            let signal = async move {
160                // todo(windows):
161                // `ctrl_close` does not work well, because tokio's signal handler always returns soon,
162                // but system terminates the application soon after returning CTRL+CLOSE handler.
163                // So we should implement blocking handler to treat CTRL+CLOSE signal.
164                let mut ctrl_break = tokio::signal::windows::ctrl_break()
165                    .expect("failed to listen for interrupt signal");
166                let mut ctrl_c = tokio::signal::windows::ctrl_c()
167                    .expect("failed to listen for interrupt signal");
168                let ctrl_break = ctrl_break.recv();
169                let ctrl_c = ctrl_c.recv();
170                futures::pin_mut!(ctrl_break, ctrl_c);
171                futures::future::select(ctrl_break, ctrl_c).await;
172            };
173
174            axum::Server::from_tcp(listener)
175                .map_err(|e| anyhow!(e))?
176                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
177                .with_graceful_shutdown(async move {
178                    signal.await;
179                    tracing::info!("Received interrupt signal");
180
181                    if let Some(on_shutdown) = on_shutdown {
182                        on_shutdown();
183                    }
184                })
185                .await
186                .map_err(|e| anyhow!(e))?;
187        }
188        _ => {
189            Err(anyhow!(
190                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
191            ))?;
192        }
193    }
194    Ok(())
195}
196
197async fn setup_app_database(config: &Config) -> Result<()> {
198    let db_options = db::ConnectOptions::new(config.database_url.clone());
199    let mut db = Database::new(db_options).await?;
200
201    db.initialize_notification_kinds().await?;
202
203    if config.seed_path.is_some() {
204        collab::seed::seed(config, &db, false).await?;
205    }
206
207    Ok(())
208}
209
210async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
211    format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
212}
213
214async fn handle_liveness_probe(app_state: Option<Extension<Arc<AppState>>>) -> Result<String> {
215    if let Some(state) = app_state {
216        state.db.get_all_users(0, 1).await?;
217    }
218
219    Ok("ok".to_string())
220}
221
222pub fn init_tracing(config: &Config) -> Option<()> {
223    use std::str::FromStr;
224    use tracing_subscriber::layer::SubscriberExt;
225
226    let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
227
228    tracing_subscriber::registry()
229        .with(if config.log_json.unwrap_or(false) {
230            Box::new(
231                tracing_subscriber::fmt::layer()
232                    .fmt_fields(JsonFields::default())
233                    .event_format(
234                        tracing_subscriber::fmt::format()
235                            .json()
236                            .flatten_event(true)
237                            .with_span_list(false),
238                    )
239                    .with_filter(filter),
240            ) as Box<dyn Layer<_> + Send + Sync>
241        } else {
242            Box::new(
243                tracing_subscriber::fmt::layer()
244                    .event_format(tracing_subscriber::fmt::format().pretty())
245                    .with_filter(filter),
246            )
247        })
248        .init();
249
250    None
251}
252
253fn init_panic_hook() {
254    std::panic::set_hook(Box::new(move |panic_info| {
255        let panic_message = match panic_info.payload().downcast_ref::<&'static str>() {
256            Some(message) => *message,
257            None => match panic_info.payload().downcast_ref::<String>() {
258                Some(message) => message.as_str(),
259                None => "Box<Any>",
260            },
261        };
262        let backtrace = std::backtrace::Backtrace::force_capture();
263        let location = panic_info
264            .location()
265            .map(|loc| format!("{}:{}", loc.file(), loc.line()));
266        tracing::error!(panic = true, ?location, %panic_message, %backtrace, "Server Panic");
267    }));
268}