main.rs

  1use anyhow::anyhow;
  2use axum::{extract::MatchedPath, routing::get, Extension, Router};
  3use collab::{
  4    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
  5    Config, MigrateConfig, Result,
  6};
  7use db::Database;
  8use hyper::Request;
  9use std::{
 10    env::args,
 11    net::{SocketAddr, TcpListener},
 12    path::Path,
 13    sync::Arc,
 14};
 15#[cfg(unix)]
 16use tokio::signal::unix::SignalKind;
 17use tower_http::trace::{self, TraceLayer};
 18use tracing::Level;
 19use tracing_log::LogTracer;
 20use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 21use util::ResultExt;
 22
 23const VERSION: &str = env!("CARGO_PKG_VERSION");
 24const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 25
 26#[tokio::main]
 27async fn main() -> Result<()> {
 28    if let Err(error) = env::load_dotenv() {
 29        eprintln!(
 30            "error loading .env.toml (this is expected in production): {}",
 31            error
 32        );
 33    }
 34
 35    let mut args = args().skip(1);
 36    match args.next().as_deref() {
 37        Some("version") => {
 38            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 39        }
 40        Some("migrate") => {
 41            run_migrations().await?;
 42        }
 43        Some("serve") => {
 44            let (is_api, is_collab) = if let Some(next) = args.next() {
 45                (next == "api", next == "collab")
 46            } else {
 47                (true, true)
 48            };
 49            if !is_api && !is_collab {
 50                Err(anyhow!(
 51                    "usage: collab <version | migrate | serve [api|collab]>"
 52                ))?;
 53            }
 54
 55            let config = envy::from_env::<Config>().expect("error loading config");
 56            init_tracing(&config);
 57
 58            run_migrations().await?;
 59
 60            let state = AppState::new(config).await?;
 61
 62            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
 63                .expect("failed to bind TCP listener");
 64
 65            let rpc_server = if is_collab {
 66                let epoch = state
 67                    .db
 68                    .create_server(&state.config.zed_environment)
 69                    .await?;
 70                let rpc_server =
 71                    collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
 72                rpc_server.start().await?;
 73
 74                Some(rpc_server)
 75            } else {
 76                None
 77            };
 78
 79            if is_api {
 80                fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
 81            }
 82
 83            let mut app = collab::api::routes(rpc_server.clone(), state.clone());
 84            if let Some(rpc_server) = rpc_server.clone() {
 85                app = app.merge(collab::rpc::routes(rpc_server))
 86            }
 87            app = app
 88                .merge(
 89                    Router::new()
 90                        .route("/", get(handle_root))
 91                        .route("/healthz", get(handle_liveness_probe))
 92                        .merge(collab::api::extensions::router())
 93                        .merge(collab::api::events::router())
 94                        .layer(Extension(state.clone())),
 95                )
 96                .layer(
 97                    TraceLayer::new_for_http()
 98                        .make_span_with(|request: &Request<_>| {
 99                            let matched_path = request
100                                .extensions()
101                                .get::<MatchedPath>()
102                                .map(MatchedPath::as_str);
103
104                            tracing::info_span!(
105                                "http_request",
106                                method = ?request.method(),
107                                matched_path,
108                            )
109                        })
110                        .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
111                );
112
113            #[cfg(unix)]
114            axum::Server::from_tcp(listener)?
115                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
116                .with_graceful_shutdown(async move {
117                    let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
118                        .expect("failed to listen for interrupt signal");
119                    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
120                        .expect("failed to listen for interrupt signal");
121                    let sigterm = sigterm.recv();
122                    let sigint = sigint.recv();
123                    futures::pin_mut!(sigterm, sigint);
124                    futures::future::select(sigterm, sigint).await;
125                    tracing::info!("Received interrupt signal");
126
127                    if let Some(rpc_server) = rpc_server {
128                        rpc_server.teardown();
129                    }
130                })
131                .await?;
132
133            // todo("windows")
134            #[cfg(windows)]
135            unimplemented!();
136        }
137        _ => {
138            Err(anyhow!(
139                "usage: collab <version | migrate | serve [api|collab]>"
140            ))?;
141        }
142    }
143    Ok(())
144}
145
146async fn run_migrations() -> Result<()> {
147    let config = envy::from_env::<MigrateConfig>().expect("error loading config");
148    let db_options = db::ConnectOptions::new(config.database_url.clone());
149    let db = Database::new(db_options, Executor::Production).await?;
150
151    let migrations_path = config
152        .migrations_path
153        .as_deref()
154        .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
155
156    let migrations = db.migrate(&migrations_path, false).await?;
157    for (migration, duration) in migrations {
158        log::info!(
159            "Migrated {} {} {:?}",
160            migration.version,
161            migration.description,
162            duration
163        );
164    }
165
166    return Ok(());
167}
168
169async fn handle_root() -> String {
170    format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
171}
172
173async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
174    state.db.get_all_users(0, 1).await?;
175    Ok("ok".to_string())
176}
177
178pub fn init_tracing(config: &Config) -> Option<()> {
179    use std::str::FromStr;
180    use tracing_subscriber::layer::SubscriberExt;
181    let rust_log = config.rust_log.clone()?;
182
183    LogTracer::init().log_err()?;
184
185    let subscriber = tracing_subscriber::Registry::default()
186        .with(if config.log_json.unwrap_or(false) {
187            Box::new(
188                tracing_subscriber::fmt::layer()
189                    .fmt_fields(JsonFields::default())
190                    .event_format(
191                        tracing_subscriber::fmt::format()
192                            .json()
193                            .flatten_event(true)
194                            .with_span_list(true),
195                    ),
196            ) as Box<dyn Layer<_> + Send + Sync>
197        } else {
198            Box::new(
199                tracing_subscriber::fmt::layer()
200                    .event_format(tracing_subscriber::fmt::format().pretty()),
201            )
202        })
203        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
204
205    tracing::subscriber::set_global_default(subscriber).unwrap();
206
207    None
208}