main.rs

  1use anyhow::anyhow;
  2use axum::{extract::MatchedPath, http::Request, routing::get, Extension, Router};
  3use collab::{
  4    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
  5    Config, MigrateConfig, Result,
  6};
  7use db::Database;
  8use std::{
  9    env::args,
 10    net::{SocketAddr, TcpListener},
 11    path::Path,
 12    sync::Arc,
 13};
 14#[cfg(unix)]
 15use tokio::signal::unix::SignalKind;
 16use tower_http::trace::{self, TraceLayer};
 17use tracing::Level;
 18use tracing_subscriber::{
 19    filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
 20};
 21use util::ResultExt;
 22
 23const VERSION: &str = env!("CARGO_PKG_VERSION");
 24const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 25
 26#[tokio::main]
 27async fn main() -> Result<()> {
 28    if let Err(error) = env::load_dotenv() {
 29        eprintln!(
 30            "error loading .env.toml (this is expected in production): {}",
 31            error
 32        );
 33    }
 34
 35    let mut args = args().skip(1);
 36    match args.next().as_deref() {
 37        Some("version") => {
 38            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 39        }
 40        Some("migrate") => {
 41            run_migrations().await?;
 42        }
 43        Some("serve") => {
 44            let (is_api, is_collab) = if let Some(next) = args.next() {
 45                (next == "api", next == "collab")
 46            } else {
 47                (true, true)
 48            };
 49            if !is_api && !is_collab {
 50                Err(anyhow!(
 51                    "usage: collab <version | migrate | serve [api|collab]>"
 52                ))?;
 53            }
 54
 55            let config = envy::from_env::<Config>().expect("error loading config");
 56            init_tracing(&config);
 57
 58            run_migrations().await?;
 59
 60            let state = AppState::new(config).await?;
 61
 62            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
 63                .expect("failed to bind TCP listener");
 64
 65            let rpc_server = if is_collab {
 66                let epoch = state
 67                    .db
 68                    .create_server(&state.config.zed_environment)
 69                    .await?;
 70                let rpc_server =
 71                    collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
 72                rpc_server.start().await?;
 73
 74                Some(rpc_server)
 75            } else {
 76                None
 77            };
 78
 79            if is_api {
 80                fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
 81            }
 82
 83            let mut app = collab::api::routes(rpc_server.clone(), state.clone());
 84            if let Some(rpc_server) = rpc_server.clone() {
 85                app = app.merge(collab::rpc::routes(rpc_server))
 86            }
 87            app = app
 88                .merge(
 89                    Router::new()
 90                        .route("/", get(handle_root))
 91                        .route("/healthz", get(handle_liveness_probe))
 92                        .merge(collab::api::extensions::router())
 93                        .merge(collab::api::events::router())
 94                        .layer(Extension(state.clone())),
 95                )
 96                .layer(
 97                    TraceLayer::new_for_http()
 98                        .make_span_with(|request: &Request<_>| {
 99                            let matched_path = request
100                                .extensions()
101                                .get::<MatchedPath>()
102                                .map(MatchedPath::as_str);
103
104                            tracing::info_span!(
105                                "http_request",
106                                method = ?request.method(),
107                                matched_path,
108                            )
109                        })
110                        .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
111                );
112
113            #[cfg(unix)]
114            axum::Server::from_tcp(listener)
115                .map_err(|e| anyhow!(e))?
116                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
117                .with_graceful_shutdown(async move {
118                    let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
119                        .expect("failed to listen for interrupt signal");
120                    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
121                        .expect("failed to listen for interrupt signal");
122                    let sigterm = sigterm.recv();
123                    let sigint = sigint.recv();
124                    futures::pin_mut!(sigterm, sigint);
125                    futures::future::select(sigterm, sigint).await;
126                    tracing::info!("Received interrupt signal");
127
128                    if let Some(rpc_server) = rpc_server {
129                        rpc_server.teardown();
130                    }
131                })
132                .await
133                .map_err(|e| anyhow!(e))?;
134
135            // todo(windows)
136            #[cfg(windows)]
137            unimplemented!();
138        }
139        _ => {
140            Err(anyhow!(
141                "usage: collab <version | migrate | serve [api|collab]>"
142            ))?;
143        }
144    }
145    Ok(())
146}
147
148async fn run_migrations() -> Result<()> {
149    let config = envy::from_env::<MigrateConfig>().expect("error loading config");
150    let db_options = db::ConnectOptions::new(config.database_url.clone());
151    let db = Database::new(db_options, Executor::Production).await?;
152
153    let migrations_path = config
154        .migrations_path
155        .as_deref()
156        .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
157
158    let migrations = db.migrate(&migrations_path, false).await?;
159    for (migration, duration) in migrations {
160        log::info!(
161            "Migrated {} {} {:?}",
162            migration.version,
163            migration.description,
164            duration
165        );
166    }
167
168    return Ok(());
169}
170
171async fn handle_root() -> String {
172    format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
173}
174
175async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
176    state.db.get_all_users(0, 1).await?;
177    Ok("ok".to_string())
178}
179
180pub fn init_tracing(config: &Config) -> Option<()> {
181    use std::str::FromStr;
182    use tracing_subscriber::layer::SubscriberExt;
183
184    let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
185
186    tracing_subscriber::registry()
187        .with(if config.log_json.unwrap_or(false) {
188            Box::new(
189                tracing_subscriber::fmt::layer()
190                    .fmt_fields(JsonFields::default())
191                    .event_format(
192                        tracing_subscriber::fmt::format()
193                            .json()
194                            .flatten_event(true)
195                            .with_span_list(true),
196                    )
197                    .with_filter(filter),
198            ) as Box<dyn Layer<_> + Send + Sync>
199        } else {
200            Box::new(
201                tracing_subscriber::fmt::layer()
202                    .event_format(tracing_subscriber::fmt::format().pretty())
203                    .with_filter(filter),
204            )
205        })
206        .init();
207
208    None
209}