main.rs

  1use anyhow::anyhow;
  2use axum::{extract::MatchedPath, routing::get, Extension, Router};
  3use collab::{
  4    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
  5    Config, MigrateConfig, Result,
  6};
  7use db::Database;
  8use hyper::Request;
  9use std::{
 10    env::args,
 11    net::{SocketAddr, TcpListener},
 12    path::Path,
 13    sync::Arc,
 14};
 15use tokio::signal::unix::SignalKind;
 16use tower_http::trace::{self, TraceLayer};
 17use tracing::Level;
 18use tracing_log::LogTracer;
 19use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 20use util::ResultExt;
 21
 22const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 23const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 24
 25#[tokio::main]
 26async fn main() -> Result<()> {
 27    if let Err(error) = env::load_dotenv() {
 28        eprintln!(
 29            "error loading .env.toml (this is expected in production): {}",
 30            error
 31        );
 32    }
 33
 34    let mut args = args().skip(1);
 35    match args.next().as_deref() {
 36        Some("version") => {
 37            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 38        }
 39        Some("migrate") => {
 40            run_migrations().await?;
 41        }
 42        Some("serve") => {
 43            let is_api_only = args.next().is_some_and(|arg| arg == "api");
 44
 45            let config = envy::from_env::<Config>().expect("error loading config");
 46            init_tracing(&config);
 47
 48            run_migrations().await?;
 49
 50            let state = AppState::new(config).await?;
 51
 52            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
 53                .expect("failed to bind TCP listener");
 54
 55            let rpc_server = if !is_api_only {
 56                let epoch = state
 57                    .db
 58                    .create_server(&state.config.zed_environment)
 59                    .await?;
 60                let rpc_server =
 61                    collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
 62                rpc_server.start().await?;
 63
 64                Some(rpc_server)
 65            } else {
 66                None
 67            };
 68
 69            // TODO: Once we move the extensions endpoints to run inside `api` service, move the background task as well.
 70            if !is_api_only {
 71                fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
 72            }
 73
 74            let mut app = collab::api::routes(rpc_server.clone(), state.clone());
 75            if let Some(rpc_server) = rpc_server.clone() {
 76                app = app.merge(collab::rpc::routes(rpc_server))
 77            }
 78            app = app
 79                .merge(
 80                    Router::new()
 81                        .route("/", get(handle_root))
 82                        .route("/healthz", get(handle_liveness_probe))
 83                        .merge(collab::api::events::router())
 84                        .layer(Extension(state.clone())),
 85                )
 86                .layer(
 87                    TraceLayer::new_for_http()
 88                        .make_span_with(|request: &Request<_>| {
 89                            let matched_path = request
 90                                .extensions()
 91                                .get::<MatchedPath>()
 92                                .map(MatchedPath::as_str);
 93
 94                            tracing::info_span!(
 95                                "http_request",
 96                                method = ?request.method(),
 97                                matched_path,
 98                            )
 99                        })
100                        .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
101                );
102
103            axum::Server::from_tcp(listener)?
104                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
105                .with_graceful_shutdown(async move {
106                    let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
107                        .expect("failed to listen for interrupt signal");
108                    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
109                        .expect("failed to listen for interrupt signal");
110                    let sigterm = sigterm.recv();
111                    let sigint = sigint.recv();
112                    futures::pin_mut!(sigterm, sigint);
113                    futures::future::select(sigterm, sigint).await;
114                    tracing::info!("Received interrupt signal");
115
116                    if let Some(rpc_server) = rpc_server {
117                        rpc_server.teardown();
118                    }
119                })
120                .await?;
121        }
122        _ => {
123            Err(anyhow!("usage: collab <version | migrate | serve>"))?;
124        }
125    }
126    Ok(())
127}
128
129async fn run_migrations() -> Result<()> {
130    let config = envy::from_env::<MigrateConfig>().expect("error loading config");
131    let db_options = db::ConnectOptions::new(config.database_url.clone());
132    let db = Database::new(db_options, Executor::Production).await?;
133
134    let migrations_path = config
135        .migrations_path
136        .as_deref()
137        .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
138
139    let migrations = db.migrate(&migrations_path, false).await?;
140    for (migration, duration) in migrations {
141        log::info!(
142            "Migrated {} {} {:?}",
143            migration.version,
144            migration.description,
145            duration
146        );
147    }
148
149    return Ok(());
150}
151
152async fn handle_root() -> String {
153    format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
154}
155
156async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
157    state.db.get_all_users(0, 1).await?;
158    Ok("ok".to_string())
159}
160
161pub fn init_tracing(config: &Config) -> Option<()> {
162    use std::str::FromStr;
163    use tracing_subscriber::layer::SubscriberExt;
164    let rust_log = config.rust_log.clone()?;
165
166    LogTracer::init().log_err()?;
167
168    let subscriber = tracing_subscriber::Registry::default()
169        .with(if config.log_json.unwrap_or(false) {
170            Box::new(
171                tracing_subscriber::fmt::layer()
172                    .fmt_fields(JsonFields::default())
173                    .event_format(
174                        tracing_subscriber::fmt::format()
175                            .json()
176                            .flatten_event(true)
177                            .with_span_list(true),
178                    ),
179            ) as Box<dyn Layer<_> + Send + Sync>
180        } else {
181            Box::new(
182                tracing_subscriber::fmt::layer()
183                    .event_format(tracing_subscriber::fmt::format().pretty()),
184            )
185        })
186        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
187
188    tracing::subscriber::set_global_default(subscriber).unwrap();
189
190    None
191}