main.rs

  1use anyhow::anyhow;
  2use axum::{routing::get, Extension, Router};
  3use collab::{db, env, executor::Executor, AppState, Config, MigrateConfig, Result};
  4use db::Database;
  5use std::{
  6    env::args,
  7    net::{SocketAddr, TcpListener},
  8    path::Path,
  9    sync::Arc,
 10};
 11use tokio::signal::unix::SignalKind;
 12use tracing_log::LogTracer;
 13use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 14use util::ResultExt;
 15
 16const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 17
 18#[tokio::main]
 19async fn main() -> Result<()> {
 20    if let Err(error) = env::load_dotenv() {
 21        eprintln!(
 22            "error loading .env.toml (this is expected in production): {}",
 23            error
 24        );
 25    }
 26
 27    match args().skip(1).next().as_deref() {
 28        Some("version") => {
 29            println!("collab v{VERSION}");
 30        }
 31        Some("serve") => {
 32            let config = envy::from_env::<Config>().expect("error loading config");
 33            init_tracing(&config);
 34
 35            run_migrations().await?;
 36
 37            let state = AppState::new(config).await?;
 38
 39            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
 40                .expect("failed to bind TCP listener");
 41
 42            let epoch = state
 43                .db
 44                .create_server(&state.config.zed_environment)
 45                .await?;
 46            let rpc_server = collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
 47            rpc_server.start().await?;
 48
 49            let app = collab::api::routes(rpc_server.clone(), state.clone())
 50                .merge(collab::rpc::routes(rpc_server.clone()))
 51                .merge(
 52                    Router::new()
 53                        .route("/", get(handle_root))
 54                        .route("/healthz", get(handle_liveness_probe))
 55                        .layer(Extension(state.clone())),
 56                );
 57
 58            axum::Server::from_tcp(listener)?
 59                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
 60                .with_graceful_shutdown(async move {
 61                    let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
 62                        .expect("failed to listen for interrupt signal");
 63                    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
 64                        .expect("failed to listen for interrupt signal");
 65                    let sigterm = sigterm.recv();
 66                    let sigint = sigint.recv();
 67                    futures::pin_mut!(sigterm, sigint);
 68                    futures::future::select(sigterm, sigint).await;
 69                    tracing::info!("Received interrupt signal");
 70                    rpc_server.teardown();
 71                })
 72                .await?;
 73        }
 74        _ => {
 75            Err(anyhow!("usage: collab <version | migrate | serve>"))?;
 76        }
 77    }
 78    Ok(())
 79}
 80
 81async fn run_migrations() -> Result<()> {
 82    let config = envy::from_env::<MigrateConfig>().expect("error loading config");
 83    let db_options = db::ConnectOptions::new(config.database_url.clone());
 84    let db = Database::new(db_options, Executor::Production).await?;
 85
 86    let migrations_path = config
 87        .migrations_path
 88        .as_deref()
 89        .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
 90
 91    let migrations = db.migrate(&migrations_path, false).await?;
 92    for (migration, duration) in migrations {
 93        log::info!(
 94            "Migrated {} {} {:?}",
 95            migration.version,
 96            migration.description,
 97            duration
 98        );
 99    }
100
101    return Ok(());
102}
103
104async fn handle_root() -> String {
105    format!("collab v{VERSION}")
106}
107
108async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
109    state.db.get_all_users(0, 1).await?;
110    Ok("ok".to_string())
111}
112
113pub fn init_tracing(config: &Config) -> Option<()> {
114    use std::str::FromStr;
115    use tracing_subscriber::layer::SubscriberExt;
116    let rust_log = config.rust_log.clone()?;
117
118    LogTracer::init().log_err()?;
119
120    let subscriber = tracing_subscriber::Registry::default()
121        .with(if config.log_json.unwrap_or(false) {
122            Box::new(
123                tracing_subscriber::fmt::layer()
124                    .fmt_fields(JsonFields::default())
125                    .event_format(
126                        tracing_subscriber::fmt::format()
127                            .json()
128                            .flatten_event(true)
129                            .with_span_list(true),
130                    ),
131            ) as Box<dyn Layer<_> + Send + Sync>
132        } else {
133            Box::new(
134                tracing_subscriber::fmt::layer()
135                    .event_format(tracing_subscriber::fmt::format().pretty()),
136            )
137        })
138        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
139
140    tracing::subscriber::set_global_default(subscriber).unwrap();
141
142    None
143}